diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 49f3dafd06..31ffc7fd6f 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -5,8 +5,9 @@ import functools import inspect import logging +import types from collections.abc import Awaitable, Callable -from typing import Any, TypeVar +from typing import Any, TypeVar, overload from ..observability import create_processing_span from ._events import ( @@ -326,34 +327,37 @@ def _discover_handlers(self) -> None: """Discover message handlers in the executor class.""" # Use __class__.__dict__ to avoid accessing pydantic's dynamic attributes for attr_name in dir(self.__class__): + # Narrow the exception scope - only catch AttributeError when accessing the attribute try: attr = getattr(self.__class__, attr_name) - # Discover @handler methods - if callable(attr) and hasattr(attr, "_handler_spec"): - handler_spec = attr._handler_spec # type: ignore - message_type = handler_spec["message_type"] - - # Keep full generic types for handler registration to avoid conflicts - if self._handlers.get(message_type) is not None: - raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}") - - # Get the bound method - bound_method = getattr(self, attr_name) - self._handlers[message_type] = bound_method - - # Add to unified handler specs list - self._handler_specs.append({ - "name": handler_spec["name"], - "message_type": message_type, - "output_types": handler_spec.get("output_types", []), - "workflow_output_types": handler_spec.get("workflow_output_types", []), - "ctx_annotation": handler_spec.get("ctx_annotation"), - "source": "class_method", # Distinguish from instance handlers if needed - }) except AttributeError: - # Skip attributes that may not be accessible + # Skip attributes that may not be accessible (e.g., dynamic descriptors) + logger.debug(f"Could not access attribute {attr_name} on {self.__class__.__name__}") continue + # Discover @handler methods - let AttributeError propagate for malformed handler specs + if callable(attr) and hasattr(attr, "_handler_spec"): + handler_spec = attr._handler_spec # type: ignore + message_type = handler_spec["message_type"] + + # Keep full generic types for handler registration to avoid conflicts + if self._handlers.get(message_type) is not None: + raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}") + + # Get the bound method + bound_method = getattr(self, attr_name) + self._handlers[message_type] = bound_method + + # Add to unified handler specs list + self._handler_specs.append({ + "name": handler_spec["name"], + "message_type": message_type, + "output_types": handler_spec.get("output_types", []), + "workflow_output_types": handler_spec.get("workflow_output_types", []), + "ctx_annotation": handler_spec.get("ctx_annotation"), + "source": "class_method", # Distinguish from instance handlers if needed + }) + def can_handle(self, message: Message) -> bool: """Check if the executor can handle a given message type. @@ -529,33 +533,97 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: ContextT = TypeVar("ContextT", bound="WorkflowContext[Any, Any]") +@overload def handler( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], -) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: +) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: ... + + +@overload +def handler( + *, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, +) -> Callable[ + [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], + Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], +]: ... + + +def handler( + func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] | None = None, + *, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, +) -> ( + Callable[[ExecutorT, Any, ContextT], Awaitable[Any]] + | Callable[ + [Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]], + Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], + ] +): """Decorator to register a handler for an executor. Args: - func: The function to decorate. Can be None when used without parameters. + func: The function to decorate. Can be None when used with parameters. + input_type: Optional explicit input type(s) for this handler. Supports union types + (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). + When provided, takes precedence over introspection from the function's message + parameter annotation. + output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + Supports union types (e.g., ``str | int``) and string forward references. + When provided, takes precedence over introspection from the ``WorkflowContext`` + generic parameters. Returns: The decorated function with handler metadata. Example: - @handler - async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: - ... + .. code-block:: python + + # Using introspection (existing behavior) + @handler + async def handle_string(self, message: str, ctx: WorkflowContext[str]) -> None: ... + - @handler - async def handle_data(self, message: dict, ctx: WorkflowContext[str | int]) -> None: - ... + # Using explicit types (takes precedence over introspection) + @handler(input_type=str | int, output_type=bool) + async def handle_data(self, message: Any, ctx: WorkflowContext) -> None: ... + + + # Using string forward references + @handler(input_type="MyCustomType | int", output_type="ResponseType") + async def handle_custom(self, message: Any, ctx: WorkflowContext) -> None: ... """ + from ._typing_utils import normalize_type_to_list, resolve_type_annotation def decorator( func: Callable[[ExecutorT, Any, ContextT], Awaitable[Any]], ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: + # Resolve string forward references using the function's globals + resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None + resolved_output_type = ( + resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None + ) + # Extract the message type and validate using unified validation - message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - _validate_handler_signature(func) + introspected_message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( + _validate_handler_signature(func, skip_message_annotation=resolved_input_type is not None) + ) + + # Use explicit types if provided, otherwise fall back to introspection + message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type + + # Validate that we have a message type - this should never happen if signature + # validation passed, but provides a clear error if type information is missing + if message_type is None: + raise ValueError( + f"Handler {func.__name__} requires either a message parameter type annotation " + "or an explicit input_type parameter" + ) + + final_output_types = ( + normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types ) # Get signature for preservation @@ -574,14 +642,19 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: "name": func.__name__, "message_type": message_type, # Keep output_types and workflow_output_types in spec for validators - "output_types": inferred_output_types, + "output_types": final_output_types, "workflow_output_types": inferred_workflow_output_types, "ctx_annotation": ctx_annotation, } return wrapper - return decorator(func) + # Handle both @handler and @handler(...) usage patterns + if func is not None: + # Called as @handler without parentheses + return decorator(func) + # Called as @handler(...) with parentheses + return decorator # endregion: Handler Decorator @@ -589,14 +662,21 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: # region Handler Validation -def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: +def _validate_handler_signature( + func: Callable[..., Any], + *, + skip_message_annotation: bool = False, +) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]: """Validate function signature for executor functions. Args: func: The function to validate + skip_message_annotation: If True, skip validation that message parameter has a type + annotation. Used when input_type is explicitly provided to the @handler decorator. Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types). + message_type may be None if skip_message_annotation is True and no annotation exists. Raises: ValueError: If the function signature is invalid @@ -609,9 +689,9 @@ def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, li if len(params) != expected_counts: raise ValueError(f"Handler {func.__name__} must have {param_description}. Got {len(params)} parameters.") - # Check message parameter has type annotation + # Check message parameter has type annotation (unless skipped) message_param = params[1] - if message_param.annotation == inspect.Parameter.empty: + if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty: raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter") # Validate ctx parameter is WorkflowContext and extract type args @@ -620,7 +700,7 @@ def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, li ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler" ) - message_type = message_param.annotation + message_type = message_param.annotation if message_param.annotation != inspect.Parameter.empty else None ctx_annotation = ctx_param.annotation return message_type, ctx_annotation, output_types, workflow_output_types diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index d7b68c10fd..54108c14d0 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -18,11 +18,13 @@ import asyncio import inspect import sys +import types import typing from collections.abc import Awaitable, Callable from typing import Any from ._executor import Executor +from ._typing_utils import normalize_type_to_list, resolve_type_annotation from ._workflow_context import WorkflowContext, validate_workflow_context_annotation if sys.version_info >= (3, 11): @@ -41,12 +43,27 @@ class FunctionExecutor(Executor): blocking the event loop. """ - def __init__(self, func: Callable[..., Any], id: str | None = None): + def __init__( + self, + func: Callable[..., Any], + id: str | None = None, + *, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, + ): """Initialize the FunctionExecutor with a user-defined function. Args: func: The function to wrap as an executor (can be sync or async) id: Optional executor ID. If None, uses the function name. + input_type: Optional explicit input type(s) for this executor. Supports union types + (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). + When provided, takes precedence over introspection from the function's message + parameter annotation. + output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + Supports union types (e.g., ``str | int``) and string forward references. + When provided, takes precedence over introspection from the ``WorkflowContext`` + generic parameters. Raises: ValueError: If func is a staticmethod or classmethod (use @handler on instance methods instead) @@ -60,8 +77,29 @@ def __init__(self, func: Callable[..., Any], id: str | None = None): f"or create an Executor subclass and use @handler on instance methods instead." ) + # Resolve string forward references using the function's globals + resolved_input_type = resolve_type_annotation(input_type, func.__globals__) if input_type is not None else None + resolved_output_type = ( + resolve_type_annotation(output_type, func.__globals__) if output_type is not None else None + ) + # Validate function signature and extract types - message_type, ctx_annotation, output_types, workflow_output_types = _validate_function_signature(func) + introspected_message_type, ctx_annotation, inferred_output_types, workflow_output_types = ( + _validate_function_signature(func, skip_message_annotation=resolved_input_type is not None) + ) + + # Use explicit types if provided, otherwise fall back to introspection + message_type = resolved_input_type if resolved_input_type is not None else introspected_message_type + output_types = ( + normalize_type_to_list(resolved_output_type) if resolved_output_type is not None else inferred_output_types + ) + + # Validate that we have a message type - provides a clear error if type information is missing + if message_type is None: + raise ValueError( + f"Function {func.__name__} requires either a message parameter type annotation " + "or an explicit input_type parameter" + ) # Store the original function self._original_func = func @@ -127,11 +165,20 @@ def executor(func: Callable[..., Any]) -> FunctionExecutor: ... @overload -def executor(*, id: str | None = None) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... +def executor( + *, + id: str | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, +) -> Callable[[Callable[..., Any]], FunctionExecutor]: ... def executor( - func: Callable[..., Any] | None = None, *, id: str | None = None + func: Callable[..., Any] | None = None, + *, + id: str | None = None, + input_type: type | types.UnionType | str | None = None, + output_type: type | types.UnionType | str | None = None, ) -> Callable[[Callable[..., Any]], FunctionExecutor] | FunctionExecutor: """Decorator that converts a standalone function into a FunctionExecutor instance. @@ -162,6 +209,17 @@ def process_data(data: str): return data.upper() + # Using explicit types (takes precedence over introspection): + @executor(id="my_executor", input_type=str | int, output_type=bool) + async def process(message: Any, ctx: WorkflowContext): + await ctx.send_message(True) + + + # Using string forward references: + @executor(input_type="MyCustomType | int", output_type="ResponseType") + async def process(message: Any, ctx: WorkflowContext): ... + + # For class-based executors, use @handler instead: class MyExecutor(Executor): def __init__(self): @@ -174,6 +232,14 @@ async def process(self, data: str, ctx: WorkflowContext[str]): Args: func: The function to decorate (when used without parentheses) id: Optional custom ID for the executor. If None, uses the function name. + input_type: Optional explicit input type(s) for this executor. Supports union types + (e.g., ``str | int``) and string forward references (e.g., ``"MyType | int"``). + When provided, takes precedence over introspection from the function's message + parameter annotation. + output_type: Optional explicit output type(s) that can be sent via ``ctx.send_message()``. + Supports union types (e.g., ``str | int``) and string forward references. + When provided, takes precedence over introspection from the ``WorkflowContext`` + generic parameters. Returns: A FunctionExecutor instance that can be wired into a Workflow. @@ -183,7 +249,7 @@ async def process(self, data: str, ctx: WorkflowContext[str]): """ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: - return FunctionExecutor(func, id=id) + return FunctionExecutor(func, id=id, input_type=input_type, output_type=output_type) # If func is provided, this means @executor was used without parentheses if func is not None: @@ -198,14 +264,21 @@ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: # region Function Validation -def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: +def _validate_function_signature( + func: Callable[..., Any], + *, + skip_message_annotation: bool = False, +) -> tuple[type | None, Any, list[type[Any]], list[type[Any]]]: """Validate function signature for executor functions. Args: func: The function to validate + skip_message_annotation: If True, skip validation that message parameter has a type + annotation. Used when input_type is explicitly provided to the @executor decorator. Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types). + message_type may be None if skip_message_annotation is True and no annotation exists. Raises: ValueError: If the function signature is invalid @@ -220,13 +293,15 @@ def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, l f"Function instance {func.__name__} must have {param_description}. Got {len(params)} parameters." ) - # Check message parameter has type annotation + # Check message parameter has type annotation (unless skipped) message_param = params[0] - if message_param.annotation == inspect.Parameter.empty: + if not skip_message_annotation and message_param.annotation == inspect.Parameter.empty: raise ValueError(f"Function instance {func.__name__} must have a type annotation for the message parameter") type_hints = typing.get_type_hints(func) message_type = type_hints.get(message_param.name, message_param.annotation) + if message_type == inspect.Parameter.empty: + message_type = None # Check if there's a context parameter if len(params) == 2: diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 5619fb9bf3..d0e9490a24 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -1,14 +1,100 @@ # Copyright (c) Microsoft. All rights reserved. -import logging from types import UnionType from typing import Any, TypeVar, Union, cast, get_args, get_origin -logger = logging.getLogger(__name__) +from agent_framework import get_logger + +logger = get_logger("agent_framework._workflows._typing_utils") T = TypeVar("T") +def resolve_type_annotation( + type_annotation: type[Any] | UnionType | str | None, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, +) -> type[Any] | UnionType | None: + """Resolve a type annotation, including string forward references. + + Args: + type_annotation: A type, union type, string forward reference, or None + globalns: Global namespace for resolving forward references (typically func.__globals__) + localns: Local namespace for resolving forward references + + Returns: + The resolved type annotation. For string annotations, evaluates them in the + provided namespace. Returns None if type_annotation is None. + + Raises: + NameError: If a forward reference cannot be resolved in the provided namespaces + SyntaxError: If a string annotation contains invalid Python syntax + + Note: + This function uses eval() to resolve string type annotations. This is the same + approach used by Python's typing.get_type_hints() and typing.ForwardRef internally. + Security is managed by: (1) strings come from decorator parameters in source code, + not runtime user input, and (2) the eval namespace is restricted to the function's + module globals plus Union/Optional from typing. + + Examples: + - resolve_type_annotation(str) -> str + - resolve_type_annotation("str | int", {"str": str, "int": int}) -> str | int + - resolve_type_annotation("MyClass", {"MyClass": MyClass}) -> MyClass + """ + if type_annotation is None: + return None + + if isinstance(type_annotation, str): + # Resolve string forward reference by evaluating it. + # This uses eval() which is the same approach as Python's typing.get_type_hints() + # and typing.ForwardRef._evaluate(). The namespace is restricted to the function's + # globals plus typing constructs, and input comes from developer source code. + eval_globalns = globalns.copy() if globalns else {} + eval_globalns.setdefault("Union", Union) + eval_globalns.setdefault("Optional", __import__("typing").Optional) + + try: + return eval(type_annotation, eval_globalns, localns) # noqa: S307 # nosec B307 + except NameError as e: + raise NameError( + f"Could not resolve type annotation '{type_annotation}'. " + f"Make sure the type is defined or imported. Original error: {e}" + ) from e + + return type_annotation + + +def normalize_type_to_list(type_annotation: type[Any] | UnionType | None) -> list[type[Any]]: + """Normalize a type annotation (possibly a union) to a list of concrete types. + + Args: + type_annotation: A type, union type (using | or Union[]), or None + + Returns: + A list of types. For union types, returns all members. + For None, returns an empty list. + For Optional[T] (Union[T, None]), returns [T, type(None)]. + + Examples: + - normalize_type_to_list(str) -> [str] + - normalize_type_to_list(str | int) -> [str, int] + - normalize_type_to_list(Union[str, int]) -> [str, int] + - normalize_type_to_list(None) -> [] + """ + if type_annotation is None: + return [] + + origin = get_origin(type_annotation) + + # Handle Union types (str | int or Union[str, int]) + if origin is Union or origin is UnionType: + return list(get_args(type_annotation)) + + # Single type + return [type_annotation] + + def is_instance_of(data: Any, target_type: type | UnionType | Any) -> bool: """Check if the data is an instance of the target type. diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index a812f6dae6..b34015d9b5 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from dataclasses import dataclass + import pytest from agent_framework import ( @@ -16,6 +18,27 @@ ) +# Module-level types for string forward reference tests +@dataclass +class ForwardRefMessage: + content: str + + +@dataclass +class ForwardRefTypeA: + value: str + + +@dataclass +class ForwardRefTypeB: + value: int + + +@dataclass +class ForwardRefResponse: + result: str + + def test_executor_without_id(): """Test that an executor without an ID raises an error when trying to run.""" @@ -538,3 +561,245 @@ async def mutator(messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMes f"{[m.text for m in mutator_invoked.data]}" ) assert mutator_invoked.data[0].text == "hello" + + +# region: Tests for @handler decorator with explicit input_type and output_type + + +class TestHandlerExplicitTypes: + """Test suite for @handler decorator with explicit input_type and output_type parameters.""" + + def test_handler_with_explicit_input_type(self): + """Test that explicit input_type takes precedence over introspection.""" + from typing import Any + + class ExplicitInputExecutor(Executor): + @handler(input_type=str) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = ExplicitInputExecutor(id="explicit_input") + + # Handler should be registered for str (explicit), not Any (introspected) + assert str in exec_instance._handlers + assert len(exec_instance._handlers) == 1 + + # Can handle str messages + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + # Cannot handle int messages (since explicit type is str) + assert not exec_instance.can_handle(Message(data=42, source_id="mock")) + + def test_handler_with_explicit_output_type(self): + """Test that explicit output_type takes precedence over introspection.""" + + class ExplicitOutputExecutor(Executor): + @handler(output_type=int) + async def handle(self, message: str, ctx: WorkflowContext[str]) -> None: + pass + + exec_instance = ExplicitOutputExecutor(id="explicit_output") + + # Handler spec should have int as output type (explicit), not str (introspected) + handler_func = exec_instance._handlers[str] + assert handler_func._handler_spec["output_types"] == [int] + + # Executor output_types property should reflect explicit type + assert int in exec_instance.output_types + assert str not in exec_instance.output_types + + def test_handler_with_explicit_input_and_output_types(self): + """Test that both explicit input_type and output_type work together.""" + from typing import Any + + class ExplicitBothExecutor(Executor): + @handler(input_type=dict, output_type=list) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = ExplicitBothExecutor(id="explicit_both") + + # Handler should be registered for dict (explicit input type) + assert dict in exec_instance._handlers + assert len(exec_instance._handlers) == 1 + + # Output type should be list (explicit) + handler_func = exec_instance._handlers[dict] + assert handler_func._handler_spec["output_types"] == [list] + + # Verify can_handle + assert exec_instance.can_handle(Message(data={"key": "value"}, source_id="mock")) + assert not exec_instance.can_handle(Message(data="string", source_id="mock")) + + def test_handler_with_explicit_union_input_type(self): + """Test that explicit union input_type is handled correctly.""" + from typing import Any + + class UnionInputExecutor(Executor): + @handler(input_type=str | int) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = UnionInputExecutor(id="union_input") + + # Handler should be registered for the union type + # The union type itself is stored as the key + assert len(exec_instance._handlers) == 1 + + # Can handle both str and int messages + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + assert exec_instance.can_handle(Message(data=42, source_id="mock")) + # Cannot handle float + assert not exec_instance.can_handle(Message(data=3.14, source_id="mock")) + + def test_handler_with_explicit_union_output_type(self): + """Test that explicit union output_type is normalized to a list.""" + from typing import Any + + class UnionOutputExecutor(Executor): + @handler(output_type=str | int | bool) + async def handle(self, message: Any, ctx: WorkflowContext) -> None: + pass + + exec_instance = UnionOutputExecutor(id="union_output") + + # Output types should be a list with all union members + assert set(exec_instance.output_types) == {str, int, bool} + + def test_handler_explicit_types_precedence_over_introspection(self): + """Test that explicit types always take precedence over introspected types.""" + + class PrecedenceExecutor(Executor): + # Introspection would give: input=str, output=[int] + # Explicit gives: input=bytes, output=[float] + @handler(input_type=bytes, output_type=float) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_instance = PrecedenceExecutor(id="precedence") + + # Should use explicit input type (bytes), not introspected (str) + assert bytes in exec_instance._handlers + assert str not in exec_instance._handlers + + # Should use explicit output type (float), not introspected (int) + assert float in exec_instance.output_types + assert int not in exec_instance.output_types + + def test_handler_fallback_to_introspection_when_no_explicit_types(self): + """Test that introspection is used when no explicit types are provided.""" + + class IntrospectedExecutor(Executor): + @handler + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_instance = IntrospectedExecutor(id="introspected") + + # Should use introspected types + assert str in exec_instance._handlers + assert int in exec_instance.output_types + + def test_handler_partial_explicit_types(self): + """Test that partial explicit types work (only input_type or only output_type).""" + + # Only explicit input_type, introspect output_type + class OnlyInputExecutor(Executor): + @handler(input_type=bytes) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_input = OnlyInputExecutor(id="only_input") + assert bytes in exec_input._handlers # Explicit + assert int in exec_input.output_types # Introspected + + # Only explicit output_type, introspect input_type + class OnlyOutputExecutor(Executor): + @handler(output_type=float) + async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: + pass + + exec_output = OnlyOutputExecutor(id="only_output") + assert str in exec_output._handlers # Introspected + assert float in exec_output.output_types # Explicit + assert int not in exec_output.output_types # Not introspected when explicit provided + + def test_handler_explicit_input_type_allows_no_message_annotation(self): + """Test that explicit input_type allows handler without message type annotation.""" + + class NoAnnotationExecutor(Executor): + @handler(input_type=str) + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = NoAnnotationExecutor(id="no_annotation") + + # Should work with explicit input_type + assert str in exec_instance._handlers + assert exec_instance.can_handle(Message(data="hello", source_id="mock")) + + def test_handler_multiple_handlers_mixed_explicit_and_introspected(self): + """Test executor with multiple handlers, some with explicit types and some introspected.""" + + class MixedExecutor(Executor): + @handler(input_type=str, output_type=int) + async def handle_explicit(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + @handler + async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None: + pass + + exec_instance = MixedExecutor(id="mixed") + + # Should have both handlers + assert len(exec_instance._handlers) == 2 + assert str in exec_instance._handlers # Explicit + assert float in exec_instance._handlers # Introspected + + # Should have both output types + assert int in exec_instance.output_types # Explicit + assert bool in exec_instance.output_types # Introspected + + def test_handler_with_string_forward_reference_input_type(self): + """Test that string forward references work for input_type.""" + + class StringRefExecutor(Executor): + @handler(input_type="ForwardRefMessage") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringRefExecutor(id="string_ref") + + # Should resolve the string to the actual type + assert ForwardRefMessage in exec_instance._handlers + assert exec_instance.can_handle(Message(data=ForwardRefMessage("hello"), source_id="mock")) + + def test_handler_with_string_forward_reference_union(self): + """Test that string forward references work with union types.""" + + class StringUnionExecutor(Executor): + @handler(input_type="ForwardRefTypeA | ForwardRefTypeB") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringUnionExecutor(id="string_union") + + # Should handle both types + assert exec_instance.can_handle(Message(data=ForwardRefTypeA("hello"), source_id="mock")) + assert exec_instance.can_handle(Message(data=ForwardRefTypeB(42), source_id="mock")) + + def test_handler_with_string_forward_reference_output_type(self): + """Test that string forward references work for output_type.""" + + class StringOutputExecutor(Executor): + @handler(input_type=str, output_type="ForwardRefResponse") + async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + exec_instance = StringOutputExecutor(id="string_output") + + # Should resolve the string output type + assert ForwardRefResponse in exec_instance.output_types + + +# endregion: Tests for @handler decorator with explicit input_type and output_type diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index a034f42a38..71d6cb34e2 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +from dataclasses import dataclass from typing import Any import pytest @@ -14,6 +15,27 @@ ) +# Module-level types for string forward reference tests +@dataclass +class FuncExecForwardRefMessage: + content: str + + +@dataclass +class FuncExecForwardRefTypeA: + value: str + + +@dataclass +class FuncExecForwardRefTypeB: + value: int + + +@dataclass +class FuncExecForwardRefResponse: + result: str + + class TestFunctionExecutor: """Test suite for FunctionExecutor and @executor decorator.""" @@ -535,3 +557,235 @@ class C: async_static = static_wrapped assert asyncio.iscoroutinefunction(C.async_static) # Works via descriptor protocol + + +class TestExecutorExplicitTypes: + """Test suite for @executor decorator with explicit input_type and output_type parameters.""" + + def test_executor_with_explicit_input_type(self): + """Test that explicit input_type takes precedence over introspection.""" + + @executor(input_type=str) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Handler should be registered for str (explicit) + assert str in process._handlers + assert len(process._handlers) == 1 + + # Can handle str messages + assert process.can_handle(Message(data="hello", source_id="mock")) + # Cannot handle int messages + assert not process.can_handle(Message(data=42, source_id="mock")) + + def test_executor_with_explicit_output_type(self): + """Test that explicit output_type takes precedence over introspection.""" + + @executor(output_type=int) + async def process(message: str, ctx: WorkflowContext[str]) -> None: + pass + + # Handler spec should have int as output type (explicit), not str (introspected) + spec = process._handler_specs[0] + assert spec["output_types"] == [int] + + # Executor output_types property should reflect explicit type + assert int in process.output_types + assert str not in process.output_types + + def test_executor_with_explicit_input_and_output_types(self): + """Test that both explicit input_type and output_type work together.""" + + @executor(id="explicit_both", input_type=dict, output_type=list) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Handler should be registered for dict (explicit input type) + assert dict in process._handlers + assert len(process._handlers) == 1 + + # Output type should be list (explicit) + spec = process._handler_specs[0] + assert spec["output_types"] == [list] + + # Verify can_handle + assert process.can_handle(Message(data={"key": "value"}, source_id="mock")) + assert not process.can_handle(Message(data="string", source_id="mock")) + + def test_executor_with_explicit_union_input_type(self): + """Test that explicit union input_type is handled correctly.""" + + @executor(input_type=str | int) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Handler should be registered for the union type + assert len(process._handlers) == 1 + + # Can handle both str and int messages + assert process.can_handle(Message(data="hello", source_id="mock")) + assert process.can_handle(Message(data=42, source_id="mock")) + # Cannot handle float + assert not process.can_handle(Message(data=3.14, source_id="mock")) + + def test_executor_with_explicit_union_output_type(self): + """Test that explicit union output_type is normalized to a list.""" + + @executor(output_type=str | int | bool) + async def process(message: Any, ctx: WorkflowContext) -> None: + pass + + # Output types should be a list with all union members + assert set(process.output_types) == {str, int, bool} + + def test_executor_explicit_types_precedence_over_introspection(self): + """Test that explicit types always take precedence over introspected types.""" + + # Introspection would give: input=str, output=[int] + # Explicit gives: input=bytes, output=[float] + @executor(input_type=bytes, output_type=float) + async def process(message: str, ctx: WorkflowContext[int]) -> None: + pass + + # Should use explicit input type (bytes), not introspected (str) + assert bytes in process._handlers + assert str not in process._handlers + + # Should use explicit output type (float), not introspected (int) + assert float in process.output_types + assert int not in process.output_types + + def test_executor_fallback_to_introspection_when_no_explicit_types(self): + """Test that introspection is used when no explicit types are provided.""" + + @executor + async def process(message: str, ctx: WorkflowContext[int]) -> None: + pass + + # Should use introspected types + assert str in process._handlers + assert int in process.output_types + + def test_executor_partial_explicit_types(self): + """Test that partial explicit types work (only input_type or only output_type).""" + + # Only explicit input_type, introspect output_type + @executor(input_type=bytes) + async def process_input(message: str, ctx: WorkflowContext[int]) -> None: + pass + + assert bytes in process_input._handlers # Explicit + assert int in process_input.output_types # Introspected + + # Only explicit output_type, introspect input_type + @executor(output_type=float) + async def process_output(message: str, ctx: WorkflowContext[int]) -> None: + pass + + assert str in process_output._handlers # Introspected + assert float in process_output.output_types # Explicit + assert int not in process_output.output_types # Not introspected when explicit provided + + def test_executor_explicit_input_type_allows_no_message_annotation(self): + """Test that explicit input_type allows function without message type annotation.""" + + @executor(input_type=str) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should work with explicit input_type + assert str in process._handlers + assert process.can_handle(Message(data="hello", source_id="mock")) + + def test_executor_explicit_types_with_id(self): + """Test that explicit types work together with id parameter.""" + + @executor(id="custom_id", input_type=bytes, output_type=int) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + assert process.id == "custom_id" + assert bytes in process._handlers + assert int in process.output_types + + def test_executor_explicit_types_with_single_param_function(self): + """Test that explicit input_type works with single-parameter functions.""" + + @executor(input_type=str) + async def process(message): # type: ignore[no-untyped-def] + return message.upper() + + # Should work with explicit input_type + assert str in process._handlers + assert process.can_handle(Message(data="hello", source_id="mock")) + assert not process.can_handle(Message(data=42, source_id="mock")) + + def test_executor_explicit_types_with_sync_function(self): + """Test that explicit types work with synchronous functions.""" + + @executor(input_type=int, output_type=str) + def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + assert int in process._handlers + assert str in process.output_types + + def test_function_executor_constructor_with_explicit_types(self): + """Test FunctionExecutor constructor with explicit input_type and output_type.""" + + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + func_exec = FunctionExecutor(process, id="test", input_type=dict, output_type=list) + + assert dict in func_exec._handlers + spec = func_exec._handler_specs[0] + assert spec["message_type"] is dict + assert spec["output_types"] == [list] + + def test_executor_explicit_union_types_via_typing_union(self): + """Test that Union[] syntax also works for explicit types.""" + from typing import Union + + @executor(input_type=Union[str, int], output_type=Union[bool, float]) + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Can handle both str and int + assert process.can_handle(Message(data="hello", source_id="mock")) + assert process.can_handle(Message(data=42, source_id="mock")) + + # Output types should include both + assert set(process.output_types) == {bool, float} + + def test_executor_with_string_forward_reference_input_type(self): + """Test that string forward references work for input_type.""" + + @executor(input_type="FuncExecForwardRefMessage") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should resolve the string to the actual type + assert FuncExecForwardRefMessage in process._handlers + assert process.can_handle(Message(data=FuncExecForwardRefMessage("hello"), source_id="mock")) + + def test_executor_with_string_forward_reference_union(self): + """Test that string forward references work with union types.""" + + @executor(input_type="FuncExecForwardRefTypeA | FuncExecForwardRefTypeB") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should handle both types + assert process.can_handle(Message(data=FuncExecForwardRefTypeA("hello"), source_id="mock")) + assert process.can_handle(Message(data=FuncExecForwardRefTypeB(42), source_id="mock")) + + def test_executor_with_string_forward_reference_output_type(self): + """Test that string forward references work for output_type.""" + + @executor(input_type=str, output_type="FuncExecForwardRefResponse") + async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] + pass + + # Should resolve the string output type + assert FuncExecForwardRefResponse in process.output_types diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index 4294f35f4b..3e8d1051e7 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -1,16 +1,153 @@ # Copyright (c) Microsoft. All rights reserved. from dataclasses import dataclass -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, Optional, TypeVar, Union + +import pytest from agent_framework import RequestInfoEvent from agent_framework._workflows._typing_utils import ( deserialize_type, is_instance_of, is_type_compatible, + normalize_type_to_list, + resolve_type_annotation, serialize_type, ) +# region: normalize_type_to_list tests + + +def test_normalize_type_to_list_single_type() -> None: + """Test normalize_type_to_list with single types.""" + assert normalize_type_to_list(str) == [str] + assert normalize_type_to_list(int) == [int] + assert normalize_type_to_list(float) == [float] + assert normalize_type_to_list(bool) == [bool] + assert normalize_type_to_list(list) == [list] + assert normalize_type_to_list(dict) == [dict] + + +def test_normalize_type_to_list_none() -> None: + """Test normalize_type_to_list with None returns empty list.""" + assert normalize_type_to_list(None) == [] + + +def test_normalize_type_to_list_union_pipe_syntax() -> None: + """Test normalize_type_to_list with union types using | syntax.""" + result = normalize_type_to_list(str | int) + assert set(result) == {str, int} + + result = normalize_type_to_list(str | int | bool) + assert set(result) == {str, int, bool} + + +def test_normalize_type_to_list_union_typing_syntax() -> None: + """Test normalize_type_to_list with Union[] from typing module.""" + result = normalize_type_to_list(Union[str, int]) + assert set(result) == {str, int} + + result = normalize_type_to_list(Union[str, int, bool]) + assert set(result) == {str, int, bool} + + +def test_normalize_type_to_list_optional() -> None: + """Test normalize_type_to_list with Optional types (Union[T, None]).""" + # Optional[str] is Union[str, None] + result = normalize_type_to_list(Optional[str]) + assert str in result + assert type(None) in result + assert len(result) == 2 + + # str | None is equivalent + result = normalize_type_to_list(str | None) + assert str in result + assert type(None) in result + assert len(result) == 2 + + +def test_normalize_type_to_list_custom_types() -> None: + """Test normalize_type_to_list with custom class types.""" + + @dataclass + class CustomMessage: + content: str + + result = normalize_type_to_list(CustomMessage) + assert result == [CustomMessage] + + result = normalize_type_to_list(CustomMessage | str) + assert set(result) == {CustomMessage, str} + + +# endregion: normalize_type_to_list tests + + +# region: resolve_type_annotation tests + + +def test_resolve_type_annotation_none() -> None: + """Test resolve_type_annotation with None returns None.""" + assert resolve_type_annotation(None) is None + + +def test_resolve_type_annotation_actual_types() -> None: + """Test resolve_type_annotation passes through actual types unchanged.""" + assert resolve_type_annotation(str) is str + assert resolve_type_annotation(int) is int + assert resolve_type_annotation(str | int) == str | int + + +def test_resolve_type_annotation_string_builtin() -> None: + """Test resolve_type_annotation resolves string references to builtin types.""" + result = resolve_type_annotation("str", {"str": str}) + assert result is str + + result = resolve_type_annotation("int", {"int": int}) + assert result is int + + +def test_resolve_type_annotation_string_union() -> None: + """Test resolve_type_annotation resolves string union types.""" + result = resolve_type_annotation("str | int", {"str": str, "int": int}) + assert result == str | int + + +def test_resolve_type_annotation_string_custom_type() -> None: + """Test resolve_type_annotation resolves string references to custom types.""" + + @dataclass + class MyCustomType: + value: int + + result = resolve_type_annotation("MyCustomType", {"MyCustomType": MyCustomType}) + assert result is MyCustomType + + result = resolve_type_annotation("MyCustomType | str", {"MyCustomType": MyCustomType, "str": str}) + assert set(result.__args__) == {MyCustomType, str} # type: ignore[union-attr] + + +def test_resolve_type_annotation_string_typing_union() -> None: + """Test resolve_type_annotation resolves Union[] syntax in strings.""" + result = resolve_type_annotation("Union[str, int]", {"str": str, "int": int}) + assert set(result.__args__) == {str, int} # type: ignore[union-attr] + + +def test_resolve_type_annotation_string_optional() -> None: + """Test resolve_type_annotation resolves Optional[] syntax in strings.""" + result = resolve_type_annotation("Optional[str]", {"str": str}) + assert str in result.__args__ # type: ignore[union-attr] + assert type(None) in result.__args__ # type: ignore[union-attr] + + +def test_resolve_type_annotation_unresolvable_raises() -> None: + """Test resolve_type_annotation raises NameError for unresolvable types.""" + with pytest.raises(NameError, match="Could not resolve type annotation"): + resolve_type_annotation("NonExistentType", {}) + + +# endregion: resolve_type_annotation tests + def test_basic_types() -> None: """Test basic built-in types.""" diff --git a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py index b5c80062dd..d070173885 100644 --- a/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py +++ b/python/samples/getting_started/workflows/_start-here/step1_executors_and_edges.py @@ -33,6 +33,16 @@ Simple steps can use this form; a terminal step can yield output using ctx.yield_output() to provide workflow results. +- Explicit type parameters with @handler: + Instead of relying on type introspection from function signatures, you can explicitly + specify `input_type` and/or `output_type` on the @handler decorator. These explicit + types take precedence over introspection and support union types (e.g., `str | int`). + + Examples: + @handler(input_type=str | int) # Accepts str or int, output from introspection + @handler(output_type=str | int) # Input from introspection, outputs str or int + @handler(input_type=str, output_type=int) # Both explicitly specified + - Fluent WorkflowBuilder API: add_edge(A, B) to connect nodes, set_start_executor(A), then build() -> Workflow. @@ -45,8 +55,8 @@ """ -# Example 1: A custom Executor subclass -# ------------------------------------ +# Example 1: A custom Executor subclass using introspection (traditional approach) +# --------------------------------------------------------------------------------- # # Subclassing Executor lets you define a named node with lifecycle hooks if needed. # The work itself is implemented in an async method decorated with @handler. @@ -70,14 +80,15 @@ async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None: Note: The WorkflowContext is parameterized with the type this handler will emit. Here WorkflowContext[str] means downstream nodes should expect str. """ + result = text.upper() # Send the result to the next executor in the workflow. await ctx.send_message(result) -# Example 2: A standalone function-based executor -# ----------------------------------------------- +# Example 2: A standalone function-based executor using introspection +# -------------------------------------------------------------------- # # For simple steps you can skip subclassing and define an async function with the # same signature pattern (typed input + WorkflowContext[T_Out, T_W_Out]) and decorate it with @@ -101,30 +112,94 @@ async def reverse_text(text: str, ctx: WorkflowContext[Never, str]) -> None: await ctx.yield_output(result) +# Example 3: Using explicit type parameters on @handler +# ----------------------------------------------------- +# +# Instead of relying on type introspection, you can explicitly specify input_type +# and/or output_type on the @handler decorator. These take precedence over introspection +# and support union types (e.g., str | int). +# +# This is useful when: +# - You want to accept multiple types (union types) without complex type annotations +# - The function signature uses Any or a base type for flexibility +# - You want to decouple the runtime type routing from the static type annotations + + +class ExclamationAdder(Executor): + """An executor that adds exclamation marks, demonstrating explicit @handler types. + + This example shows how to use explicit input_type and output_type parameters + on the @handler decorator instead of relying on introspection from the function + signature. This approach is especially useful for union types. + """ + + def __init__(self, id: str): + super().__init__(id=id) + + @handler(input_type=str, output_type=str) + async def add_exclamation(self, message: str, ctx: WorkflowContext) -> None: + """Add exclamation marks to the input. + + Note: The input_type=str and output_type=str are explicitly specified on @handler, + so the framework uses those instead of introspecting the function signature. + The WorkflowContext here has no type parameters because the explicit types + on @handler take precedence. + """ + result = f"{message}!!!" + await ctx.send_message(result) + + async def main(): - """Build and run a simple 2-step workflow using the fluent builder API.""" + """Build and run workflows using the fluent builder API.""" + # Workflow 1: Using introspection-based type detection + # ----------------------------------------------------- upper_case = UpperCase(id="upper_case_executor") # Build the workflow using a fluent pattern: # 1) add_edge(from_node, to_node) defines a directed edge upper_case -> reverse_text # 2) set_start_executor(node) declares the entry point # 3) build() finalizes and returns an immutable Workflow object - workflow = WorkflowBuilder().add_edge(upper_case, reverse_text).set_start_executor(upper_case).build() + workflow1 = WorkflowBuilder().add_edge(upper_case, reverse_text).set_start_executor(upper_case).build() # Run the workflow by sending the initial message to the start node. # The run(...) call returns an event collection; its get_outputs() method # retrieves the outputs yielded by any terminal nodes. - events = await workflow.run("hello world") - print(events.get_outputs()) - # Summarize the final run state (e.g., IDLE) - print("Final state:", events.get_final_state()) + print("Workflow 1 (introspection-based types):") + events1 = await workflow1.run("hello world") + print(events1.get_outputs()) + print("Final state:", events1.get_final_state()) + + # Workflow 2: Using explicit type parameters on @handler + # ------------------------------------------------------- + exclamation_adder = ExclamationAdder(id="exclamation_adder") + + # This workflow demonstrates the explicit input_type/output_type feature: + # exclamation_adder uses @handler(input_type=str, output_type=str) to + # explicitly declare types instead of relying on introspection. + workflow2 = ( + WorkflowBuilder() + .add_edge(upper_case, exclamation_adder) + .add_edge(exclamation_adder, reverse_text) + .set_start_executor(upper_case) + .build() + ) + + print("\nWorkflow 2 (explicit @handler types):") + events2 = await workflow2.run("hello world") + print(events2.get_outputs()) + print("Final state:", events2.get_final_state()) """ Sample Output: + Workflow 1 (introspection-based types): ['DLROW OLLEH'] Final state: WorkflowRunState.IDLE + + Workflow 2 (explicit @handler types): + ['!!!DLROW OLLEH'] + Final state: WorkflowRunState.IDLE """