From 026678030146666eebfe7e0e83d60facd27fc85a Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi <mpippi@gmail.com>
Date: Thu, 27 Feb 2025 17:54:22 +0100
Subject: [PATCH] feat: auto-detect custom start and stop events in workflow
 classes (#17865)

---
 .../llama_index/core/workflow/workflow.py     |  66 ++++++---
 llama-index-core/tests/workflow/conftest.py   |   2 +-
 .../tests/workflow/test_decorator.py          |   9 +-
 .../tests/workflow/test_service.py            |  19 ++-
 .../tests/workflow/test_workflow.py           | 128 ++++++++++++------
 5 files changed, 151 insertions(+), 73 deletions(-)

diff --git a/llama-index-core/llama_index/core/workflow/workflow.py b/llama-index-core/llama_index/core/workflow/workflow.py
index bf4c322868..98c10df4ab 100644
--- a/llama-index-core/llama_index/core/workflow/workflow.py
+++ b/llama-index-core/llama_index/core/workflow/workflow.py
@@ -68,8 +68,6 @@ class Workflow(metaclass=WorkflowMeta):
         verbose: bool = False,
         service_manager: Optional[ServiceManager] = None,
         num_concurrent_runs: Optional[int] = None,
-        stop_event_class: type[StopEvent] = StopEvent,
-        start_event_class: type[StartEvent] = StartEvent,
     ) -> None:
         """Create an instance of the workflow.
 
@@ -89,24 +87,14 @@ class Workflow(metaclass=WorkflowMeta):
             num_concurrent_runs:
                 maximum number of .run() executions occurring simultaneously. If set to `None`, there
                 is no limit to this number.
-            stop_event_class:
-                custom type used instead of StopEvent
-            start_event_class:
-                custom type used instead of StartEvent
         """
         # Configuration
         self._timeout = timeout
         self._verbose = verbose
         self._disable_validation = disable_validation
         self._num_concurrent_runs = num_concurrent_runs
-        self._stop_event_class = stop_event_class
-        if not issubclass(self._stop_event_class, StopEvent):
-            msg = f"Stop event class '{stop_event_class.__name__}' must derive from 'StopEvent'"
-            raise WorkflowConfigurationError(msg)
-        self._start_event_class = start_event_class
-        if not issubclass(self._start_event_class, StartEvent):
-            msg = f"Start event class '{start_event_class.__name__}' must derive from 'StartEvent'"
-            raise WorkflowConfigurationError(msg)
+        self._stop_event_class = self._ensure_stop_event_class()
+        self._start_event_class = self._ensure_start_event_class()
         self._sem = (
             asyncio.Semaphore(num_concurrent_runs) if num_concurrent_runs else None
         )
@@ -116,6 +104,50 @@ class Workflow(metaclass=WorkflowMeta):
         # Services management
         self._service_manager = service_manager or ServiceManager()
 
+    def _ensure_start_event_class(self) -> type[StartEvent]:
+        """Return the single StartEvent type accepted by the steps of this workflow.
+
+        Raises WorkflowConfigurationError when the steps accept no StartEvent
+        type at all, or more than one distinct StartEvent type.
+        """
+        start_events_found: set[type[StartEvent]] = {
+            event_type
+            for step_func in self._get_steps().values()
+            for event_type in getattr(step_func, "__step_config").accepted_events
+            if issubclass(event_type, StartEvent)
+        }
+        if not start_events_found:
+            msg = "At least one Event of type StartEvent must be received by any step."
+            raise WorkflowConfigurationError(msg)
+        if len(start_events_found) > 1:
+            # Sort by name: set iteration order is arbitrary, keep the message stable.
+            names = sorted(cls.__name__ for cls in start_events_found)
+            msg = f"Only one type of StartEvent is allowed per workflow, found {len(names)}: {names}."
+            raise WorkflowConfigurationError(msg)
+        return start_events_found.pop()
+
+    def _ensure_stop_event_class(self) -> type[StopEvent]:
+        """Return the single StopEvent type returned by the steps of this workflow.
+
+        Raises WorkflowConfigurationError when the steps return no StopEvent
+        type at all, or more than one distinct StopEvent type.
+        """
+        stop_events_found: set[type[StopEvent]] = {
+            event_type
+            for step_func in self._get_steps().values()
+            for event_type in getattr(step_func, "__step_config").return_types
+            if issubclass(event_type, StopEvent)
+        }
+        if not stop_events_found:
+            msg = "At least one Event of type StopEvent must be returned by any step."
+            raise WorkflowConfigurationError(msg)
+        if len(stop_events_found) > 1:
+            # Sort by name: set iteration order is arbitrary, keep the message stable.
+            names = sorted(cls.__name__ for cls in stop_events_found)
+            msg = f"Only one type of StopEvent is allowed per workflow, found {len(names)}: {names}."
+            raise WorkflowConfigurationError(msg)
+        return stop_events_found.pop()
+
     async def stream_events(self) -> AsyncGenerator[Event, None]:
         """Returns an async generator to consume any event that workflow steps decide to stream.
 
@@ -402,7 +434,7 @@ class Workflow(metaclass=WorkflowMeta):
         ctx = next(iter(self._contexts))
         ctx.send_event(message=message, step=step)
 
-    def _get_start_event(
+    def _get_start_event_instance(
         self, start_event: Optional[StartEvent], **kwargs: Any
     ) -> StartEvent:
         if start_event is not None:
@@ -459,7 +491,9 @@ class Workflow(metaclass=WorkflowMeta):
             try:
                 if not ctx.is_running:
                     # Send the first event
-                    start_event_instance = self._get_start_event(start_event, **kwargs)
+                    start_event_instance = self._get_start_event_instance(
+                        start_event, **kwargs
+                    )
                     ctx.send_event(start_event_instance)
 
                     # the context is now running
diff --git a/llama-index-core/tests/workflow/conftest.py b/llama-index-core/tests/workflow/conftest.py
index 80658991f2..6e7f1a45bc 100644
--- a/llama-index-core/tests/workflow/conftest.py
+++ b/llama-index-core/tests/workflow/conftest.py
@@ -44,4 +44,4 @@ def events():
 
 @pytest.fixture()
 def ctx():
-    return Context(workflow=Workflow())
+    return Context(workflow=DummyWorkflow())
diff --git a/llama-index-core/tests/workflow/test_decorator.py b/llama-index-core/tests/workflow/test_decorator.py
index 48559afcf9..3a072ef52b 100644
--- a/llama-index-core/tests/workflow/test_decorator.py
+++ b/llama-index-core/tests/workflow/test_decorator.py
@@ -1,10 +1,9 @@
 import re
 
 import pytest
-
 from llama_index.core.workflow.decorators import step
 from llama_index.core.workflow.errors import WorkflowValidationError
-from llama_index.core.workflow.events import Event
+from llama_index.core.workflow.events import Event, StartEvent, StopEvent
 from llama_index.core.workflow.workflow import Workflow
 
 
@@ -22,12 +21,12 @@ def test_decorated_config(workflow):
 def test_decorate_method():
     class TestWorkflow(Workflow):
         @step
-        def f1(self, ev: Event) -> Event:
+        def f1(self, ev: StartEvent) -> Event:
             return ev
 
         @step
-        def f2(self, ev: Event) -> Event:
-            return ev
+        def f2(self, ev: Event) -> StopEvent:
+            return StopEvent()
 
     wf = TestWorkflow()
     assert getattr(wf.f1, "__step_config")
diff --git a/llama-index-core/tests/workflow/test_service.py b/llama-index-core/tests/workflow/test_service.py
index 147b5fe3ad..f818652196 100644
--- a/llama-index-core/tests/workflow/test_service.py
+++ b/llama-index-core/tests/workflow/test_service.py
@@ -1,10 +1,9 @@
 import pytest
-
+from llama_index.core.workflow.context import Context
 from llama_index.core.workflow.decorators import step
 from llama_index.core.workflow.events import Event, StartEvent, StopEvent
-from llama_index.core.workflow.workflow import Workflow
-from llama_index.core.workflow.context import Context
 from llama_index.core.workflow.service import ServiceManager, ServiceNotFoundError
+from llama_index.core.workflow.workflow import Workflow
 
 
 class ServiceWorkflow(Workflow):
@@ -67,17 +66,15 @@ async def test_default_value_for_service():
     assert res == 84
 
 
-def test_service_manager_add():
+def test_service_manager_add(workflow):
     s = ServiceManager()
-    w = Workflow()
-    s.add("test_id", w)
-    assert s._services["test_id"] == w
+    s.add("test_id", workflow)
+    assert s._services["test_id"] == workflow
 
 
-def test_service_manager_get():
+def test_service_manager_get(workflow):
     s = ServiceManager()
-    w = Workflow()
-    s._services["test_id"] = w
-    assert s.get("test_id") == w
+    s._services["test_id"] = workflow
+    assert s.get("test_id") == workflow
     with pytest.raises(ServiceNotFoundError):
         s.get("not_found")
diff --git a/llama-index-core/tests/workflow/test_workflow.py b/llama-index-core/tests/workflow/test_workflow.py
index e8951fac62..da5c4ac3c2 100644
--- a/llama-index-core/tests/workflow/test_workflow.py
+++ b/llama-index-core/tests/workflow/test_workflow.py
@@ -34,6 +34,14 @@ class EventWithName(Event):
     name: str
 
 
+class MyStart(StartEvent):  # custom start event type shared by the tests below
+    query: str
+
+
+class MyStop(StopEvent):  # custom stop event type shared by the tests below
+    outcome: str
+
+
 def test_fn():
     print("test_fn")
 
@@ -126,12 +134,11 @@ async def test_workflow_validation_unproduced_events():
         async def invalid_step(self, ev: StartEvent) -> None:
             pass
 
-    workflow = InvalidWorkflow()
     with pytest.raises(
-        WorkflowValidationError,
-        match="No event of type StopEvent is produced.",
+        WorkflowConfigurationError,
+        match="At least one Event of type StopEvent must be returned by any step.",
     ):
-        await workflow.run()
+        workflow = InvalidWorkflow()
 
 
 @pytest.mark.asyncio()
@@ -164,12 +171,11 @@ async def test_workflow_validation_start_event_not_consumed():
         async def another_step(self, ev: OneTestEvent) -> OneTestEvent:
             return OneTestEvent()
 
-    workflow = InvalidWorkflow()
     with pytest.raises(
-        WorkflowValidationError,
-        match="The following events are produced but never consumed: StartEvent",
+        WorkflowConfigurationError,
+        match="At least one Event of type StartEvent must be received by any step.",
     ):
-        await workflow.run()
+        workflow = InvalidWorkflow()
 
 
 @pytest.mark.asyncio()
@@ -221,7 +227,7 @@ async def test_workflow_num_workers():
             ctx.send_event(OneTestEvent(test_param="test3"))
 
             # send one extra event
-            ctx.session.send_event(AnotherTestEvent(another_test_param="test4"))
+            ctx.send_event(AnotherTestEvent(another_test_param="test4"))
 
             return LastEvent()
 
@@ -362,22 +368,21 @@ async def test_workflow_multiple_runs():
     assert set(results) == {6, 84, -198}
 
 
-def test_deprecated_send_event():
+def test_deprecated_send_event(workflow):
     ev = StartEvent()
-    wf = Workflow()
     ctx = mock.MagicMock()
 
     # One context, assert step emits a warning
-    wf._contexts.add(ctx)
+    workflow._contexts.add(ctx)
     with pytest.warns(UserWarning):
-        wf.send_event(message=ev)
+        workflow.send_event(message=ev)
     ctx.send_event.assert_called_with(message=ev, step=None)
 
     # Second context, assert step raises an exception
     ctx = mock.MagicMock()
-    wf._contexts.add(ctx)
+    workflow._contexts.add(ctx)
     with pytest.raises(WorkflowRuntimeError):
-        wf.send_event(message=ev)
+        workflow.send_event(message=ev)
     ctx.send_event.assert_not_called()
 
 
@@ -443,7 +448,12 @@ async def test_workflow_task_raises_step():
 
 
 def test_workflow_disable_validation():
-    w = Workflow(disable_validation=True)
+    class NoValidationWorkflow(Workflow):  # one StartEvent -> StopEvent step so __init__ succeeds
+        @step
+        async def only_step(self, ev: StartEvent) -> StopEvent:  # not named `step`: keeps the decorator unshadowed
+            raise ValueError("The step raised an error!")
+
+    w = NoValidationWorkflow(disable_validation=True)
     w._get_steps = mock.MagicMock()
     w._validate()
     w._get_steps.assert_not_called()
@@ -722,12 +732,6 @@ async def test_workflow_run_num_concurrent(
 
 @pytest.mark.asyncio()
 async def test_custom_stop_event():
-    class MyStart(StartEvent):
-        query: str
-
-    class MyStop(StopEvent):
-        outcome: str
-
     class CustomEventsWorkflow(Workflow):
         @step
         async def start_step(self, ev: MyStart) -> OneTestEvent:
@@ -741,7 +745,13 @@ async def test_custom_stop_event():
         async def end_step(self, ev: LastEvent) -> MyStop:
             return MyStop(outcome="Workflow completed")
 
-    wf = CustomEventsWorkflow(start_event_class=MyStart, stop_event_class=MyStop)
+    wf = CustomEventsWorkflow()
+    assert wf._start_event_class == MyStart
+    assert wf._stop_event_class == MyStop
+    result = await wf.run(query="foo")
+
+    # Ensure the event types can be inferred when not passed to the init
+    wf = CustomEventsWorkflow()
     assert wf._start_event_class == MyStart
     assert wf._stop_event_class == MyStop
     result = await wf.run(query="foo")
@@ -754,48 +764,86 @@ def test_is_done(workflow):
 
 
 def test_wrong_event_types():
-    class CustomEvent(Event):
+    class RandomEvent(Event):
         pass
 
+    class InvalidStopWorkflow(Workflow):
+        @step
+        async def a_step(self, ev: MyStart) -> RandomEvent:
+            return RandomEvent()
+
     with pytest.raises(
         WorkflowConfigurationError,
-        match="Start event class 'CustomEvent' must derive from 'StartEvent'",
+        match="At least one Event of type StopEvent must be returned by any step.",
     ):
-        DummyWorkflow(start_event_class=CustomEvent)  # type: ignore
+        InvalidStopWorkflow()
+
+    class InvalidStartWorkflow(Workflow):
+        @step
+        async def a_step(self, ev: RandomEvent) -> StopEvent:
+            return StopEvent()
 
     with pytest.raises(
         WorkflowConfigurationError,
-        match="Stop event class 'CustomEvent' must derive from 'StopEvent'",
+        match="At least one Event of type StartEvent must be received by any step.",
     ):
-        DummyWorkflow(stop_event_class=CustomEvent)  # type: ignore
+        InvalidStartWorkflow()
 
 
-def test__get_start_event(caplog):
+def test__get_start_event_instance(caplog):
     class CustomEvent(StartEvent):
         field: str
 
     e = CustomEvent(field="test")
-    d = DummyWorkflow(start_event_class=CustomEvent)
-
-    # Invoke run() with wrong start_event type
-    with pytest.raises(
-        ValueError,
-        match="The 'start_event' argument must be an instance of 'StartEvent'.",
-    ):
-        d._get_start_event(start_event="wrong type", arg="foo")  # type: ignore
+    d = DummyWorkflow()
+    d._start_event_class = CustomEvent
 
     # Invoke run() passing a legit start event but with additional kwargs
     with caplog.at_level(logging.WARN):
-        assert d._get_start_event(e, this_will_be_ignored=True) == e
+        assert d._get_start_event_instance(e, this_will_be_ignored=True) == e
         assert (
             "Keyword arguments are not supported when 'run()' is invoked with the 'start_event' parameter."
             in caplog.text
         )
 
     # Old style kwargs passed to the designed StartEvent
-    assert type(d._get_start_event(None, field="test")) is CustomEvent
+    assert type(d._get_start_event_instance(None, field="test")) is CustomEvent
 
     # Old style but wrong kwargs passed to the designed StartEvent
     err = "Failed creating a start event of type 'CustomEvent' with the keyword arguments: {'wrong_field': 'test'}"
     with pytest.raises(WorkflowRuntimeError, match=err):
-        d._get_start_event(None, wrong_field="test")
+        d._get_start_event_instance(None, wrong_field="test")
+
+
+def test__ensure_start_event_class_multiple_types():
+    class TwoStartTypesWorkflow(Workflow):  # steps accept both MyStart and StartEvent
+        @step
+        def one(self, ev: MyStart) -> None:
+            pass
+
+        @step
+        def two(self, ev: StartEvent) -> StopEvent:
+            return StopEvent()
+
+    with pytest.raises(
+        WorkflowConfigurationError,
+        match="Only one type of StartEvent is allowed per workflow, found 2",
+    ):
+        TwoStartTypesWorkflow()  # the check runs in __init__, no run() needed
+
+
+def test__ensure_stop_event_class_multiple_types():
+    class TwoStopTypesWorkflow(Workflow):  # steps return both MyStop and StopEvent
+        @step
+        def one(self, ev: MyStart) -> MyStop:
+            return MyStop(outcome="nope")
+
+        @step
+        def two(self, ev: MyStart) -> StopEvent:
+            return StopEvent()
+
+    with pytest.raises(
+        WorkflowConfigurationError,
+        match="Only one type of StopEvent is allowed per workflow, found 2",
+    ):
+        TwoStopTypesWorkflow()  # the check runs in __init__, no run() needed
-- 
GitLab