From 07101b31bd026f14d1bfefbb20b0a1f512c93900 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi <mpippi@gmail.com> Date: Thu, 27 Feb 2025 04:41:35 +0100 Subject: [PATCH] fix: stepwise execution breaks when steps do async work (#17914) --- docs/docs/module_guides/workflow/index.md | 26 ++++++++------- .../llama_index/core/workflow/context.py | 18 ++++++++++ .../llama_index/core/workflow/events.py | 4 +++ .../llama_index/core/workflow/handler.py | 17 +++++++--- .../llama_index/core/workflow/utils.py | 12 ++++--- .../llama_index/core/workflow/workflow.py | 3 ++ .../tests/workflow/test_stepwise.py | 33 +++++++++++++++++++ 7 files changed, 92 insertions(+), 21 deletions(-) create mode 100644 llama-index-core/tests/workflow/test_stepwise.py diff --git a/docs/docs/module_guides/workflow/index.md b/docs/docs/module_guides/workflow/index.md index 3b1f0d7693..ecdc30339b 100644 --- a/docs/docs/module_guides/workflow/index.md +++ b/docs/docs/module_guides/workflow/index.md @@ -428,18 +428,20 @@ final_result = await handler Workflows have built-in utilities for stepwise execution, allowing you to control execution and debug state as things progress. ```python -w = JokeFlow(...) - -# Kick off the workflow -handler = w.run(topic="Pirates") - -# Iterate until done -async for _ in handler: - # inspect context - # val = await handler.ctx.get("key") - continue - -# Get the final result +# Create a workflow, same as usual +w = JokeFlow() +# Get the handler. Passing `stepwise=True` will block execution, waiting for manual intervention +handler = w.run(stepwise=True) +# Each time we call `run_step`, the workflow will advance and return the Event +# that was produced in the last step. This event needs to be manually propagated +# for the workflow to keep going (we assign it to `ev` with the := operator). 
+while ev := await handler.run_step(): + # If we're here, it means there's an event we need to propagate, + # let's do it with `send_event` + handler.ctx.send_event(ev) + +# If we're here, it means the workflow execution completed, and +# we can now access the final result. result = await handler ``` diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index 5c838af536..0e964bc38a 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -53,7 +53,13 @@ class Context: ) self._accepted_events: List[Tuple[str, str]] = [] self._retval: Any = None + # Map the step names that were executed to a list of events they received. + # This will be serialized, and is needed to resume a Workflow run passing + # an existing context. self._in_progress: Dict[str, List[Event]] = defaultdict(list) + # Keep track of the steps currently running. This is only valid when a + # workflow is running and won't be serialized. + self._currently_running_steps: Set[str] = set() # Streaming machinery self._streaming_queue: asyncio.Queue = asyncio.Queue() # Global data storage @@ -199,6 +205,18 @@ class Context: events = [e for e in self._in_progress[name] if e != ev] self._in_progress[name] = events + async def add_running_step(self, name: str) -> None: + async with self.lock: + self._currently_running_steps.add(name) + + async def remove_running_step(self, name: str) -> None: + async with self.lock: + self._currently_running_steps.remove(name) + + async def running_steps(self) -> List[str]: + async with self.lock: + return list(self._currently_running_steps) + async def get(self, key: str, default: Optional[Any] = Ellipsis) -> Any: """Get the value corresponding to `key` from the Context. 
diff --git a/llama-index-core/llama_index/core/workflow/events.py b/llama-index-core/llama_index/core/workflow/events.py index 2b02a88e1b..1a7e3990b2 100644 --- a/llama-index-core/llama_index/core/workflow/events.py +++ b/llama-index-core/llama_index/core/workflow/events.py @@ -124,6 +124,10 @@ class Event(BaseModel): def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: return self._data + def __bool__(self) -> bool: + """Make test `if event:` pass on Event instances.""" + return True + @model_serializer(mode="wrap") def custom_model_dump(self, handler: Any) -> Dict[str, Any]: data = handler(self) diff --git a/llama-index-core/llama_index/core/workflow/handler.py b/llama-index-core/llama_index/core/workflow/handler.py index 98ed3f4c7b..5ac18b1e71 100644 --- a/llama-index-core/llama_index/core/workflow/handler.py +++ b/llama-index-core/llama_index/core/workflow/handler.py @@ -2,8 +2,10 @@ import asyncio from typing import Any, AsyncGenerator, Optional from llama_index.core.workflow.context import Context -from llama_index.core.workflow.events import Event, StopEvent from llama_index.core.workflow.errors import WorkflowDone +from llama_index.core.workflow.events import Event, StopEvent + +from .utils import BUSY_WAIT_DELAY class WorkflowHandler(asyncio.Future): @@ -12,7 +14,7 @@ class WorkflowHandler(asyncio.Future): *args: Any, ctx: Optional[Context] = None, run_id: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.run_id = run_id @@ -79,7 +81,7 @@ class WorkflowHandler(asyncio.Future): we_done = True e = t.exception() - if type(e) != WorkflowDone: + if type(e) is not WorkflowDone: exception_raised = e if we_done: @@ -96,7 +98,14 @@ class WorkflowHandler(asyncio.Future): if not self.done(): self.set_result(self.ctx.get_result()) - else: # continue with running next step + else: + # Continue with running next step. Make sure we wait for the + # step function to return before proceeding. 
+ in_progress = len(await self.ctx.running_steps()) + while in_progress: + await asyncio.sleep(BUSY_WAIT_DELAY) + in_progress = len(await self.ctx.running_steps()) + # notify unblocked task that we're ready to accept next event async with self.ctx._step_condition: self.ctx._step_condition.notify() diff --git a/llama-index-core/llama_index/core/workflow/utils.py b/llama-index-core/llama_index/core/workflow/utils.py index 34eb3a1106..1d8243b71e 100644 --- a/llama-index-core/llama_index/core/workflow/utils.py +++ b/llama-index-core/llama_index/core/workflow/utils.py @@ -1,14 +1,14 @@ import inspect from importlib import import_module from typing import ( - get_args, - get_origin, Any, + Callable, + Dict, List, Optional, Union, - Callable, - Dict, + get_args, + get_origin, get_type_hints, ) @@ -20,8 +20,10 @@ except ImportError: # pragma: no cover from llama_index.core.bridge.pydantic import BaseModel, ConfigDict -from .events import Event, EventType from .errors import WorkflowValidationError +from .events import Event, EventType + +BUSY_WAIT_DELAY = 0.01 class ServiceDefinition(BaseModel): diff --git a/llama-index-core/llama_index/core/workflow/workflow.py b/llama-index-core/llama_index/core/workflow/workflow.py index 5c9a3d86ee..bf4c322868 100644 --- a/llama-index-core/llama_index/core/workflow/workflow.py +++ b/llama-index-core/llama_index/core/workflow/workflow.py @@ -266,6 +266,7 @@ class Workflow(metaclass=WorkflowMeta): attempts = 0 while True: await ctx.mark_in_progress(name=name, ev=ev) + await ctx.add_running_step(name) try: new_ev = await instrumented_step(**kwargs) break # exit the retrying loop @@ -293,6 +294,8 @@ class Workflow(metaclass=WorkflowMeta): f"Step {name} produced an error, retry in {delay} seconds" ) await asyncio.sleep(delay) + finally: + await ctx.remove_running_step(name) else: try: diff --git a/llama-index-core/tests/workflow/test_stepwise.py b/llama-index-core/tests/workflow/test_stepwise.py new file mode 100644 index 
0000000000..13e415944e --- /dev/null +++ b/llama-index-core/tests/workflow/test_stepwise.py @@ -0,0 +1,33 @@ +import asyncio + +import pytest +from llama_index.core.workflow.decorators import step +from llama_index.core.workflow.events import Event, StartEvent, StopEvent +from llama_index.core.workflow.workflow import Workflow + + +class IntermediateEvent(Event): + value: int + + +class StepWorkflow(Workflow): + @step + async def step1(self, ev: StartEvent) -> IntermediateEvent: + await asyncio.sleep(0.1) + return IntermediateEvent(value=21) + + @step + async def step2(self, ev: IntermediateEvent) -> StopEvent: + await asyncio.sleep(0.1) + return StopEvent(result=ev.value * 2) + + +@pytest.mark.asyncio() +async def test_simple_stepwise(): + workflow = StepWorkflow() + handler = workflow.run(stepwise=True) + while ev := await handler.run_step(): + handler.ctx.send_event(ev) # type: ignore + + result = await handler + assert result == 42 -- GitLab