Unverified commit 07101b31 authored by Massimiliano Pippi, committed by GitHub

fix: stepwise execution breaks when steps do async work (#17914)

parent 4c14e54b
@@ -428,18 +428,20 @@ final_result = await handler
 Workflows have built-in utilities for stepwise execution, allowing you to control execution and debug state as things progress.
 
 ```python
-w = JokeFlow(...)
-
-# Kick off the workflow
-handler = w.run(topic="Pirates")
-
-# Iterate until done
-async for _ in handler:
-    # inspect context
-    # val = await handler.ctx.get("key")
-    continue
-
-# Get the final result
+# Create a workflow, same as usual
+w = JokeFlow()
+
+# Get the handler. Passing `stepwise=True` will block execution, waiting for manual intervention
+handler = w.run(stepwise=True)
+
+# Each time we call `run_step`, the workflow will advance and return the Event
+# that was produced in the last step. This event needs to be manually propagated
+# for the workflow to keep going (we assign it to `ev` with the := operator).
+while ev := await handler.run_step():
+    # If we're here, it means there's an event we need to propagate,
+    # let's do it with `send_event`
+    handler.ctx.send_event(ev)
+
+# If we're here, it means the workflow execution completed, and
+# we can now access the final result.
 result = await handler
 ```
...
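For context, this is roughly how the new stepwise API fits together end to end. The sketch below assumes a `JokeFlow` workflow that accepts a `topic` keyword, as in the surrounding docs; the `my_flows` module name is made up. Inspecting shared state between steps (the `handler.ctx.get` call the old snippet hinted at) still works, because `run_step` pauses the workflow after every step.

```python
import asyncio

from my_flows import JokeFlow  # hypothetical module exporting the docs' JokeFlow


async def main() -> None:
    w = JokeFlow()
    # stepwise=True blocks after each step until we explicitly resume
    handler = w.run(stepwise=True, topic="Pirates")

    # run_step() returns the produced Event, or None once the workflow is done
    while ev := await handler.run_step():
        # Good place to inspect shared state before resuming, e.g.:
        # val = await handler.ctx.get("key", default=None)
        handler.ctx.send_event(ev)

    print(await handler)


asyncio.run(main())
```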
@@ -53,7 +53,13 @@ class Context:
         )
         self._accepted_events: List[Tuple[str, str]] = []
         self._retval: Any = None
+        # Map the step names that were executed to a list of events they received.
+        # This will be serialized, and is needed to resume a Workflow run passing
+        # an existing context.
         self._in_progress: Dict[str, List[Event]] = defaultdict(list)
+        # Keep track of the steps currently running. This is only valid when a
+        # workflow is running and won't be serialized.
+        self._currently_running_steps: Set[str] = set()
         # Streaming machinery
         self._streaming_queue: asyncio.Queue = asyncio.Queue()
         # Global data storage
@@ -199,6 +205,18 @@ class Context:
             events = [e for e in self._in_progress[name] if e != ev]
             self._in_progress[name] = events
 
+    async def add_running_step(self, name: str) -> None:
+        async with self.lock:
+            self._currently_running_steps.add(name)
+
+    async def remove_running_step(self, name: str) -> None:
+        async with self.lock:
+            self._currently_running_steps.remove(name)
+
+    async def running_steps(self) -> List[str]:
+        async with self.lock:
+            return list(self._currently_running_steps)
+
     async def get(self, key: str, default: Optional[Any] = Ellipsis) -> Any:
         """Get the value corresponding to `key` from the Context.
...
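The three new `Context` methods form a small, lock-guarded registry of the steps whose coroutines are currently executing. A minimal standalone sketch of the same pattern (the class and function names below are illustrative, not part of the library):

```python
import asyncio
from typing import List, Set


class RunningSteps:
    """Illustrative stand-in for the bookkeeping Context now does."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self._currently_running_steps: Set[str] = set()

    async def add_running_step(self, name: str) -> None:
        async with self.lock:
            self._currently_running_steps.add(name)

    async def remove_running_step(self, name: str) -> None:
        async with self.lock:
            self._currently_running_steps.remove(name)

    async def running_steps(self) -> List[str]:
        async with self.lock:
            return list(self._currently_running_steps)


async def demo() -> None:
    registry = RunningSteps()

    async def fake_step(name: str) -> None:
        await registry.add_running_step(name)
        await asyncio.sleep(0.05)  # pretend to do async work
        await registry.remove_running_step(name)

    task = asyncio.create_task(fake_step("step1"))
    await asyncio.sleep(0.01)
    print(await registry.running_steps())  # ['step1'] while the step is in flight
    await task
    print(await registry.running_steps())  # [] once it has returned


asyncio.run(demo())
```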
@@ -124,6 +124,10 @@ class Event(BaseModel):
     def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         return self._data
 
+    def __bool__(self) -> bool:
+        """Make test `if event:` pass on Event instances."""
+        return True
+
     @model_serializer(mode="wrap")
     def custom_model_dump(self, handler: Any) -> Dict[str, Any]:
         data = handler(self)
...
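The `__bool__` override exists because the stepwise loop now uses truthiness as its stop signal: `run_step()` returns `None` once the workflow has finished, while every intermediate event must stay truthy even if it carries no payload (presumably, without this override the event's dict-like interface would make an empty event evaluate to `False`). A small illustration; `EmptyEvent` is a made-up subclass:

```python
from llama_index.core.workflow.events import Event


class EmptyEvent(Event):
    """Made-up event type that carries no payload at all."""


ev = EmptyEvent()

# Event.__bool__ makes every event truthy, so the stepwise loop
#     while ev := await handler.run_step(): ...
# only terminates when run_step() returns None (workflow finished),
# never because an intermediate event happened to be empty.
assert bool(ev) is True
assert bool(None) is False
```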
@@ -2,8 +2,10 @@ import asyncio
 from typing import Any, AsyncGenerator, Optional
 
 from llama_index.core.workflow.context import Context
-from llama_index.core.workflow.events import Event, StopEvent
 from llama_index.core.workflow.errors import WorkflowDone
+from llama_index.core.workflow.events import Event, StopEvent
+
+from .utils import BUSY_WAIT_DELAY
 
 
 class WorkflowHandler(asyncio.Future):
@@ -12,7 +14,7 @@ class WorkflowHandler(asyncio.Future):
         *args: Any,
         ctx: Optional[Context] = None,
         run_id: Optional[str] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
         self.run_id = run_id
@@ -79,7 +81,7 @@ class WorkflowHandler(asyncio.Future):
                     we_done = True
                     e = t.exception()
-                    if type(e) != WorkflowDone:
+                    if type(e) is not WorkflowDone:
                         exception_raised = e
 
                 if we_done:
@@ -96,7 +98,14 @@ class WorkflowHandler(asyncio.Future):
                 if not self.done():
                     self.set_result(self.ctx.get_result())
-            else:  # continue with running next step
+            else:
+                # Continue with running next step. Make sure we wait for the
+                # step function to return before proceeding.
+                in_progress = len(await self.ctx.running_steps())
+                while in_progress:
+                    await asyncio.sleep(BUSY_WAIT_DELAY)
+                    in_progress = len(await self.ctx.running_steps())
+
                 # notify unblocked task that we're ready to accept next event
                 async with self.ctx._step_condition:
                     self.ctx._step_condition.notify()
...
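This `else` branch is the heart of the fix: before waking up the next step, the handler now busy-waits (in `BUSY_WAIT_DELAY` increments) until `Context.running_steps()` reports that no step coroutine is still executing, so a step that does real async work can no longer be overtaken. A stripped-down sketch of that waiting pattern, with hypothetical helper names:

```python
import asyncio

BUSY_WAIT_DELAY = 0.01  # same value the commit adds to the workflow utils


async def wait_for_idle(running_steps) -> None:
    """Poll the running-step registry until it reports no steps in flight."""
    while len(await running_steps()):
        await asyncio.sleep(BUSY_WAIT_DELAY)


async def demo() -> None:
    running = {"step1"}

    async def running_steps() -> list:
        return list(running)

    async def slow_step() -> None:
        await asyncio.sleep(0.05)  # the async work that used to break stepwise mode
        running.discard("step1")

    await asyncio.gather(slow_step(), wait_for_idle(running_steps))
    print("all steps returned; safe to notify the next one")


asyncio.run(demo())
```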
 import inspect
 from importlib import import_module
 from typing import (
-    get_args,
-    get_origin,
     Any,
+    Callable,
+    Dict,
     List,
     Optional,
     Union,
-    Callable,
-    Dict,
+    get_args,
+    get_origin,
     get_type_hints,
 )
@@ -20,8 +20,10 @@ except ImportError:  # pragma: no cover
 
 from llama_index.core.bridge.pydantic import BaseModel, ConfigDict
 
-from .events import Event, EventType
 from .errors import WorkflowValidationError
+from .events import Event, EventType
+
+BUSY_WAIT_DELAY = 0.01
 
 
 class ServiceDefinition(BaseModel):
...
@@ -266,6 +266,7 @@ class Workflow(metaclass=WorkflowMeta):
                 attempts = 0
                 while True:
                     await ctx.mark_in_progress(name=name, ev=ev)
+                    await ctx.add_running_step(name)
                     try:
                         new_ev = await instrumented_step(**kwargs)
                         break  # exit the retrying loop
@@ -293,6 +294,8 @@ class Workflow(metaclass=WorkflowMeta):
                                 f"Step {name} produced an error, retry in {delay} seconds"
                             )
                        await asyncio.sleep(delay)
+                    finally:
+                        await ctx.remove_running_step(name)
 
             else:
                 try:
...
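On the workflow side, each step is registered with `add_running_step` right before it is awaited, and the new `finally` clause guarantees it is unregistered even if the step raises or ends up being retried; otherwise the handler's busy-wait above could spin forever. The same pattern, reduced to a hypothetical helper:

```python
async def run_step_tracked(ctx, name, step_fn):
    """Hypothetical helper mirroring the change above: always pair
    add_running_step with remove_running_step, whatever the step does."""
    await ctx.add_running_step(name)
    try:
        return await step_fn()
    finally:
        # Runs on success, on exception, and on cancellation alike.
        await ctx.remove_running_step(name)
```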
import asyncio

import pytest

from llama_index.core.workflow.decorators import step
from llama_index.core.workflow.events import Event, StartEvent, StopEvent
from llama_index.core.workflow.workflow import Workflow


class IntermediateEvent(Event):
    value: int


class StepWorkflow(Workflow):
    @step
    async def step1(self, ev: StartEvent) -> IntermediateEvent:
        await asyncio.sleep(0.1)
        return IntermediateEvent(value=21)

    @step
    async def step2(self, ev: IntermediateEvent) -> StopEvent:
        await asyncio.sleep(0.1)
        return StopEvent(result=ev.value * 2)


@pytest.mark.asyncio()
async def test_simple_stepwise():
    workflow = StepWorkflow()
    handler = workflow.run(stepwise=True)
    while ev := await handler.run_step():
        handler.ctx.send_event(ev)  # type: ignore
    result = await handler
    assert result == 42