From 026678030146666eebfe7e0e83d60facd27fc85a Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi <mpippi@gmail.com> Date: Thu, 27 Feb 2025 17:54:22 +0100 Subject: [PATCH] feat: auto-detect custom start and stop events in workflow classes (#17865) --- .../llama_index/core/workflow/workflow.py | 66 ++++++--- llama-index-core/tests/workflow/conftest.py | 2 +- .../tests/workflow/test_decorator.py | 9 +- .../tests/workflow/test_service.py | 19 ++- .../tests/workflow/test_workflow.py | 128 ++++++++++++------ 5 files changed, 151 insertions(+), 73 deletions(-) diff --git a/llama-index-core/llama_index/core/workflow/workflow.py b/llama-index-core/llama_index/core/workflow/workflow.py index bf4c322868..98c10df4ab 100644 --- a/llama-index-core/llama_index/core/workflow/workflow.py +++ b/llama-index-core/llama_index/core/workflow/workflow.py @@ -68,8 +68,6 @@ class Workflow(metaclass=WorkflowMeta): verbose: bool = False, service_manager: Optional[ServiceManager] = None, num_concurrent_runs: Optional[int] = None, - stop_event_class: type[StopEvent] = StopEvent, - start_event_class: type[StartEvent] = StartEvent, ) -> None: """Create an instance of the workflow. @@ -89,24 +87,14 @@ class Workflow(metaclass=WorkflowMeta): num_concurrent_runs: maximum number of .run() executions occurring simultaneously. If set to `None`, there is no limit to this number. - stop_event_class: - custom type used instead of StopEvent - start_event_class: - custom type used instead of StartEvent """ # Configuration self._timeout = timeout self._verbose = verbose self._disable_validation = disable_validation self._num_concurrent_runs = num_concurrent_runs - self._stop_event_class = stop_event_class - if not issubclass(self._stop_event_class, StopEvent): - msg = f"Stop event class '{stop_event_class.__name__}' must derive from 'StopEvent'" - raise WorkflowConfigurationError(msg) - self._start_event_class = start_event_class - if not issubclass(self._start_event_class, StartEvent): - msg = f"Start event class '{start_event_class.__name__}' must derive from 'StartEvent'" - raise WorkflowConfigurationError(msg) + self._stop_event_class = self._ensure_stop_event_class() + self._start_event_class = self._ensure_start_event_class() self._sem = ( asyncio.Semaphore(num_concurrent_runs) if num_concurrent_runs else None ) @@ -116,6 +104,50 @@ class Workflow(metaclass=WorkflowMeta): # Services management self._service_manager = service_manager or ServiceManager() + def _ensure_start_event_class(self) -> type[StartEvent]: + """Returns the StartEvent type used in this workflow. + + It works by inspecting the events received by the step methods. + """ + start_events_found: set[type[StartEvent]] = set() + for step_func in self._get_steps().values(): + step_config: StepConfig = getattr(step_func, "__step_config") + for event_type in step_config.accepted_events: + if issubclass(event_type, StartEvent): + start_events_found.add(event_type) + + num_found = len(start_events_found) + if num_found == 0: + msg = "At least one Event of type StartEvent must be received by any step." + raise WorkflowConfigurationError(msg) + elif num_found > 1: + msg = f"Only one type of StartEvent is allowed per workflow, found {num_found}: {start_events_found}." + raise WorkflowConfigurationError(msg) + else: + return start_events_found.pop() + + def _ensure_stop_event_class(self) -> type[StopEvent]: + """Returns the StopEvent type used in this workflow. + + It works by inspecting the events returned. + """ + stop_events_found: set[type[StopEvent]] = set() + for step_func in self._get_steps().values(): + step_config: StepConfig = getattr(step_func, "__step_config") + for event_type in step_config.return_types: + if issubclass(event_type, StopEvent): + stop_events_found.add(event_type) + + num_found = len(stop_events_found) + if num_found == 0: + msg = "At least one Event of type StopEvent must be returned by any step." + raise WorkflowConfigurationError(msg) + elif num_found > 1: + msg = f"Only one type of StopEvent is allowed per workflow, found {num_found}: {stop_events_found}." + raise WorkflowConfigurationError(msg) + else: + return stop_events_found.pop() + async def stream_events(self) -> AsyncGenerator[Event, None]: """Returns an async generator to consume any event that workflow steps decide to stream. @@ -402,7 +434,7 @@ class Workflow(metaclass=WorkflowMeta): ctx = next(iter(self._contexts)) ctx.send_event(message=message, step=step) - def _get_start_event( + def _get_start_event_instance( self, start_event: Optional[StartEvent], **kwargs: Any ) -> StartEvent: if start_event is not None: @@ -459,7 +491,9 @@ class Workflow(metaclass=WorkflowMeta): try: if not ctx.is_running: # Send the first event - start_event_instance = self._get_start_event(start_event, **kwargs) + start_event_instance = self._get_start_event_instance( + start_event, **kwargs + ) ctx.send_event(start_event_instance) # the context is now running diff --git a/llama-index-core/tests/workflow/conftest.py b/llama-index-core/tests/workflow/conftest.py index 80658991f2..6e7f1a45bc 100644 --- a/llama-index-core/tests/workflow/conftest.py +++ b/llama-index-core/tests/workflow/conftest.py @@ -44,4 +44,4 @@ def events(): @pytest.fixture() def ctx(): - return Context(workflow=Workflow()) + return Context(workflow=DummyWorkflow()) diff --git a/llama-index-core/tests/workflow/test_decorator.py b/llama-index-core/tests/workflow/test_decorator.py index 48559afcf9..3a072ef52b 100644 --- a/llama-index-core/tests/workflow/test_decorator.py +++ b/llama-index-core/tests/workflow/test_decorator.py @@ -1,10 +1,9 @@ import re import pytest - from llama_index.core.workflow.decorators import step from llama_index.core.workflow.errors import WorkflowValidationError -from llama_index.core.workflow.events import Event +from llama_index.core.workflow.events import Event, StartEvent, StopEvent from llama_index.core.workflow.workflow import Workflow @@ -22,12 +21,12 @@ def test_decorated_config(workflow): def test_decorate_method(): class TestWorkflow(Workflow): @step - def f1(self, ev: Event) -> Event: + def f1(self, ev: StartEvent) -> Event: return ev @step - def f2(self, ev: Event) -> Event: - return ev + def f2(self, ev: Event) -> StopEvent: + return StopEvent() wf = TestWorkflow() assert getattr(wf.f1, "__step_config") diff --git a/llama-index-core/tests/workflow/test_service.py b/llama-index-core/tests/workflow/test_service.py index 147b5fe3ad..f818652196 100644 --- a/llama-index-core/tests/workflow/test_service.py +++ b/llama-index-core/tests/workflow/test_service.py @@ -1,10 +1,9 @@ import pytest - +from llama_index.core.workflow.context import Context from llama_index.core.workflow.decorators import step from llama_index.core.workflow.events import Event, StartEvent, StopEvent -from llama_index.core.workflow.workflow import Workflow -from llama_index.core.workflow.context import Context from llama_index.core.workflow.service import ServiceManager, ServiceNotFoundError +from llama_index.core.workflow.workflow import Workflow class ServiceWorkflow(Workflow): @@ -67,17 +66,15 @@ async def test_default_value_for_service(): assert res == 84 -def test_service_manager_add(): +def test_service_manager_add(workflow): s = ServiceManager() - w = Workflow() - s.add("test_id", w) - assert s._services["test_id"] == w + s.add("test_id", workflow) + assert s._services["test_id"] == workflow -def test_service_manager_get(): +def test_service_manager_get(workflow): s = ServiceManager() - w = Workflow() - s._services["test_id"] = w - assert s.get("test_id") == w + s._services["test_id"] = workflow + assert s.get("test_id") == workflow with pytest.raises(ServiceNotFoundError): s.get("not_found") diff --git a/llama-index-core/tests/workflow/test_workflow.py b/llama-index-core/tests/workflow/test_workflow.py index e8951fac62..da5c4ac3c2 100644 --- a/llama-index-core/tests/workflow/test_workflow.py +++ b/llama-index-core/tests/workflow/test_workflow.py @@ -34,6 +34,14 @@ class EventWithName(Event): name: str +class MyStart(StartEvent): + query: str + + +class MyStop(StopEvent): + outcome: str + + def test_fn(): print("test_fn") @@ -126,12 +134,11 @@ async def test_workflow_validation_unproduced_events(): async def invalid_step(self, ev: StartEvent) -> None: pass - workflow = InvalidWorkflow() with pytest.raises( - WorkflowValidationError, - match="No event of type StopEvent is produced.", + WorkflowConfigurationError, + match="At least one Event of type StopEvent must be returned by any step.", ): - await workflow.run() + workflow = InvalidWorkflow() @pytest.mark.asyncio() @@ -164,12 +171,11 @@ async def test_workflow_validation_start_event_not_consumed(): async def another_step(self, ev: OneTestEvent) -> OneTestEvent: return OneTestEvent() - workflow = InvalidWorkflow() with pytest.raises( - WorkflowValidationError, - match="The following events are produced but never consumed: StartEvent", + WorkflowConfigurationError, + match="At least one Event of type StartEvent must be received by any step.", ): - await workflow.run() + workflow = InvalidWorkflow() @pytest.mark.asyncio() @@ -221,7 +227,7 @@ async def test_workflow_num_workers(): ctx.send_event(OneTestEvent(test_param="test3")) # send one extra event - ctx.session.send_event(AnotherTestEvent(another_test_param="test4")) + ctx.send_event(AnotherTestEvent(another_test_param="test4")) return LastEvent() @@ -362,22 +368,21 @@ async def test_workflow_multiple_runs(): assert set(results) == {6, 84, -198} -def test_deprecated_send_event(): +def test_deprecated_send_event(workflow): ev = StartEvent() - wf = Workflow() ctx = mock.MagicMock() # One context, assert step emits a warning - wf._contexts.add(ctx) + workflow._contexts.add(ctx) with pytest.warns(UserWarning): - wf.send_event(message=ev) + workflow.send_event(message=ev) ctx.send_event.assert_called_with(message=ev, step=None) # Second context, assert step raises an exception ctx = mock.MagicMock() - wf._contexts.add(ctx) + workflow._contexts.add(ctx) with pytest.raises(WorkflowRuntimeError): - wf.send_event(message=ev) + workflow.send_event(message=ev) ctx.send_event.assert_not_called() @@ -443,7 +448,12 @@ async def test_workflow_task_raises_step(): def test_workflow_disable_validation(): - w = Workflow(disable_validation=True) + class DummyWorkflow(Workflow): + @step + async def step(self, ev: StartEvent) -> StopEvent: + raise ValueError("The step raised an error!") + + w = DummyWorkflow(disable_validation=True) w._get_steps = mock.MagicMock() w._validate() w._get_steps.assert_not_called() @@ -722,12 +732,6 @@ async def test_workflow_run_num_concurrent( @pytest.mark.asyncio() async def test_custom_stop_event(): - class MyStart(StartEvent): - query: str - - class MyStop(StopEvent): - outcome: str - class CustomEventsWorkflow(Workflow): @step async def start_step(self, ev: MyStart) -> OneTestEvent: @@ -741,7 +745,13 @@ async def test_custom_stop_event(): async def end_step(self, ev: LastEvent) -> MyStop: return MyStop(outcome="Workflow completed") - wf = CustomEventsWorkflow(start_event_class=MyStart, stop_event_class=MyStop) + wf = CustomEventsWorkflow() + assert wf._start_event_class == MyStart + assert wf._stop_event_class == MyStop + result = await wf.run(query="foo") + + # Ensure the event types can be inferred when not passed to the init + wf = CustomEventsWorkflow() assert wf._start_event_class == MyStart assert wf._stop_event_class == MyStop result = await wf.run(query="foo") @@ -754,48 +764,86 @@ def test_is_done(workflow): def test_wrong_event_types(): - class CustomEvent(Event): + class RandomEvent(Event): pass + class InvalidStopWorkflow(Workflow): + @step + async def a_step(self, ev: MyStart) -> RandomEvent: + return RandomEvent() + with pytest.raises( WorkflowConfigurationError, - match="Start event class 'CustomEvent' must derive from 'StartEvent'", + match="At least one Event of type StopEvent must be returned by any step.", ): - DummyWorkflow(start_event_class=CustomEvent) # type: ignore + InvalidStopWorkflow() + + class InvalidStartWorkflow(Workflow): + @step + async def a_step(self, ev: RandomEvent) -> StopEvent: + return StopEvent() with pytest.raises( WorkflowConfigurationError, - match="Stop event class 'CustomEvent' must derive from 'StopEvent'", + match="At least one Event of type StartEvent must be received by any step.", ): - DummyWorkflow(stop_event_class=CustomEvent) # type: ignore + InvalidStartWorkflow() -def test__get_start_event(caplog): +def test__get_start_event_instance(caplog): class CustomEvent(StartEvent): field: str e = CustomEvent(field="test") - d = DummyWorkflow(start_event_class=CustomEvent) - - # Invoke run() with wrong start_event type - with pytest.raises( - ValueError, - match="The 'start_event' argument must be an instance of 'StartEvent'.", - ): - d._get_start_event(start_event="wrong type", arg="foo") # type: ignore + d = DummyWorkflow() + d._start_event_class = CustomEvent # Invoke run() passing a legit start event but with additional kwargs with caplog.at_level(logging.WARN): - assert d._get_start_event(e, this_will_be_ignored=True) == e + assert d._get_start_event_instance(e, this_will_be_ignored=True) == e assert ( "Keyword arguments are not supported when 'run()' is invoked with the 'start_event' parameter." in caplog.text ) # Old style kwargs passed to the designed StartEvent - assert type(d._get_start_event(None, field="test")) is CustomEvent + assert type(d._get_start_event_instance(None, field="test")) is CustomEvent # Old style but wrong kwargs passed to the designed StartEvent err = "Failed creating a start event of type 'CustomEvent' with the keyword arguments: {'wrong_field': 'test'}" with pytest.raises(WorkflowRuntimeError, match=err): - d._get_start_event(None, wrong_field="test") + d._get_start_event_instance(None, wrong_field="test") + + +def test__ensure_start_event_class_multiple_types(): + class DummyWorkflow(Workflow): + @step + def one(self, ev: MyStart) -> None: + pass + + @step + def two(self, ev: StartEvent) -> StopEvent: + return StopEvent() + + with pytest.raises( + WorkflowConfigurationError, + match="Only one type of StartEvent is allowed per workflow, found 2", + ): + wf = DummyWorkflow() + + +def test__ensure_stop_event_class_multiple_types(): + class DummyWorkflow(Workflow): + @step + def one(self, ev: MyStart) -> MyStop: + return MyStop(outcome="nope") + + @step + def two(self, ev: MyStart) -> StopEvent: + return StopEvent() + + with pytest.raises( + WorkflowConfigurationError, + match="Only one type of StopEvent is allowed per workflow, found 2", + ): + wf = DummyWorkflow() -- GitLab