Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 121 additions & 41 deletions python/packages/core/agent_framework/_workflows/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import functools
import inspect
import logging
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from typing import Any, TypeVar, overload

from ..observability import create_processing_span
from ._events import (
Expand Down Expand Up @@ -326,34 +327,37 @@ def _discover_handlers(self) -> None:
"""Discover message handlers in the executor class."""
# Use __class__.__dict__ to avoid accessing pydantic's dynamic attributes
for attr_name in dir(self.__class__):
# Narrow the exception scope - only catch AttributeError when accessing the attribute
try:
attr = getattr(self.__class__, attr_name)
# Discover @handler methods
if callable(attr) and hasattr(attr, "_handler_spec"):
handler_spec = attr._handler_spec # type: ignore
message_type = handler_spec["message_type"]

# Keep full generic types for handler registration to avoid conflicts
if self._handlers.get(message_type) is not None:
raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}")

# Get the bound method
bound_method = getattr(self, attr_name)
self._handlers[message_type] = bound_method

# Add to unified handler specs list
self._handler_specs.append({
"name": handler_spec["name"],
"message_type": message_type,
"output_types": handler_spec.get("output_types", []),
"workflow_output_types": handler_spec.get("workflow_output_types", []),
"ctx_annotation": handler_spec.get("ctx_annotation"),
"source": "class_method", # Distinguish from instance handlers if needed
})
except AttributeError:
# Skip attributes that may not be accessible
# Skip attributes that may not be accessible (e.g., dynamic descriptors)
logger.debug(f"Could not access attribute {attr_name} on {self.__class__.__name__}")
continue

# Discover @handler methods - let AttributeError propagate for malformed handler specs
if callable(attr) and hasattr(attr, "_handler_spec"):
handler_spec = attr._handler_spec # type: ignore
message_type = handler_spec["message_type"]

# Keep full generic types for handler registration to avoid conflicts
if self._handlers.get(message_type) is not None:
raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}")

# Get the bound method
bound_method = getattr(self, attr_name)
self._handlers[message_type] = bound_method

# Add to unified handler specs list
self._handler_specs.append({
"name": handler_spec["name"],
"message_type": message_type,
"output_types": handler_spec.get("output_types", []),
"workflow_output_types": handler_spec.get("workflow_output_types", []),
"ctx_annotation": handler_spec.get("ctx_annotation"),
"source": "class_method", # Distinguish from instance handlers if needed
})

def can_handle(self, message: Message) -> bool:
"""Check if the executor can handle a given message type.

Expand Down Expand Up @@ -529,33 +533,97 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
ContextT = TypeVar("ContextT", bound="WorkflowContext[Any, Any]")


@overload
def handler(
func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]],
) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]:
) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: ...


@overload
def handler(
*,
input_type: type | types.UnionType | str | None = None,
output_type: type | types.UnionType | str | None = None,
) -> Callable[
[Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]],
Callable[[ExecutorT, Any, ContextT], Awaitable[Any]],
]: ...


def handler(
func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | None = None,
*,
input_type: type | types.UnionType | str | None = None,
output_type: type | types.UnionType | str | None = None,
) -> (
Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]
| Callable[
[Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]],
Callable[[ExecutorT, Any, ContextT], Awaitable[Any]],
]
):
"""Decorator to register a handler for an executor.

Args:
func: The function to decorate. Can be None when used without parameters.
func: The function to decorate. Can be None when used with parameters.
input_type: Optional explicit input type(s) for this handler. Supports union types
(e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``).
When provided, takes precedence over introspection from the function's message
parameter annotation.
output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``.
Supports union types (e.g., ``str | int``) and string forward references.
When provided, takes precedence over introspection from the ``WorkflowContext``
generic parameters.

Returns:
The decorated function with handler metadata.

Example:
@handler
async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None:
...
.. code-block:: python

# Using introspection (existing behavior)
@handler
async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: ...


@handler
async def handle_data(self, message: dict, ctx: WorkflowContext[str | int]) -> None:
...
# Using explicit types (takes precedence over introspection)
@handler(input_type=str | int, output_type=bool)
async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ...


# Using string forward references
@handler(input_type="MyCustomType | int", output_type="ResponseType")
async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ...
"""
from ._typing_utils import normalize_type_to_list, resolve_type_annotation

def decorator(
func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]],
) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]:
# Resolve string forward references using the function's globals
resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None
resolved_output_type = (
resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None
)

# Extract the message type and validate using unified validation
message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = (
_validate_handler_signature(func)
introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = (
_validate_handler_signature(func, skip_message_annotation=resolved_input_type is not None)
)

# Use explicit types if provided, otherwise fall back to introspection
message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type

# Validate that we have a message type - this should never happen if signature
# validation passed, but provides a clear error if type information is missing
if message_type is None:
raise ValueError(
f"Handler {func.__name__} requires either a message parameter type annotation "
"or an explicit input_type parameter"
)

final_output_types = (
normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types
)

# Get signature for preservation
Expand All @@ -574,29 +642,41 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any:
"name": func.__name__,
"message_type": message_type,
# Keep output_types and workflow_output_types in spec for validators
"output_types": inferred_output_types,
"output_types": final_output_types,
"workflow_output_types": inferred_workflow_output_types,
"ctx_annotation": ctx_annotation,
}

return wrapper

return decorator(func)
# Handle both @handler and @handler(...) usage patterns
if func is not None:
# Called as @handler without parentheses
return decorator(func)
# Called as @handler(...) with parentheses
return decorator


# endregion: Handler Decorator

# region Handler Validation


def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]:
def _validate_handler_signature(
func: Callable[..., Any],
*,
skip_message_annotation: bool = False,
) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]:
"""Validate function signature for executor functions.

Args:
func: The function to validate
skip_message_annotation: If True, skip validation that message parameter has a type
annotation. Used when input_type is explicitly provided to the @handler decorator.

Returns:
Tuple of (message_type, ctx_annotation, output_types, workflow_output_types)
Tuple of (message_type, ctx_annotation, output_types, workflow_output_types).
message_type may be None if skip_message_annotation is True and no annotation exists.

Raises:
ValueError: If the function signature is invalid
Expand All @@ -609,9 +689,9 @@ def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, li
if len(params) != expected_counts:
raise ValueError(f"Handler {func.__name__} must have {param_description}. Got {len(params)} parameters.")

# Check message parameter has type annotation
# Check message parameter has type annotation (unless skipped)
message_param = params[1]
if message_param.annotation == inspect.Parameter.empty:
if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty:
raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter")

# Validate ctx parameter is WorkflowContext and extract type args
Expand All @@ -620,7 +700,7 @@ def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, li
ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler"
)

message_type = message_param.annotation
message_type = message_param.annotation if message_param.annotation != inspect.Parameter.empty else None
ctx_annotation = ctx_param.annotation

return message_type, ctx_annotation, output_types, workflow_output_types
Expand Down
Loading
Loading