Skip to content
Snippets Groups Projects
Unverified Commit 23fe88cd authored by Logan's avatar Logan Committed by GitHub
Browse files

Ensure resuming a workflow actually works (#18023)

parent a804b39d
Branches
Tags
No related merge requests found
......@@ -362,8 +362,6 @@ class Workflow(metaclass=WorkflowMeta):
warnings.warn(
f"Step function {name} returned {type(new_ev).__name__} instead of an Event instance."
)
elif isinstance(new_ev, InputRequiredEvent):
ctx.write_event_to_stream(new_ev)
else:
if stepwise:
async with ctx._step_condition:
......@@ -393,7 +391,13 @@ class Workflow(metaclass=WorkflowMeta):
input_ev=ev,
output_ev=new_ev,
)
ctx.send_event(new_ev)
# InputRequiredEvent's are special case and need to be written to the stream
# this way, the user can access and respond to the event
if isinstance(new_ev, InputRequiredEvent):
ctx.write_event_to_stream(new_ev)
else:
ctx.send_event(new_ev)
for _ in range(step_config.num_workers):
ctx._tasks.add(
......
......@@ -618,17 +618,22 @@ async def test_workflow_context_to_dict(workflow):
assert new_ctx._queues["start_step"].get_nowait().name == "test"
@pytest.mark.asyncio()
async def test_human_in_the_loop():
class HumanInTheLoopWorkflow(Workflow):
@step
async def step1(self, ev: StartEvent) -> InputRequiredEvent:
return InputRequiredEvent(prefix="Enter a number: ")
class HumanInTheLoopWorkflow(Workflow):
@step
async def step1(self, ctx: Context, ev: StartEvent) -> InputRequiredEvent:
cur_runs = await ctx.get("step1_runs", default=0)
await ctx.set("step1_runs", cur_runs + 1)
return InputRequiredEvent(prefix="Enter a number: ")
@step
async def step2(self, ev: HumanResponseEvent) -> StopEvent:
return StopEvent(result=ev.response)
@step
async def step2(self, ctx: Context, ev: HumanResponseEvent) -> StopEvent:
cur_runs = await ctx.get("step2_runs", default=0)
await ctx.set("step2_runs", cur_runs + 1)
return StopEvent(result=ev.response)
@pytest.mark.asyncio()
async def test_human_in_the_loop():
workflow = HumanInTheLoopWorkflow(timeout=1)
# workflow should raise a timeout error because hitl only works with streaming
......@@ -653,6 +658,34 @@ async def test_human_in_the_loop():
assert final_result == "42"
@pytest.mark.asyncio()
async def test_human_in_the_loop_with_resume():
# workflow should work with streaming
workflow = HumanInTheLoopWorkflow()
handler: WorkflowHandler = workflow.run()
assert handler.ctx
ctx_dict = None
async for event in handler.stream_events():
if isinstance(event, InputRequiredEvent):
ctx_dict = handler.ctx.to_dict()
await handler.cancel_run()
break
new_handler = workflow.run(ctx=Context.from_dict(workflow, ctx_dict))
new_handler.ctx.send_event(HumanResponseEvent(response="42"))
final_result = await new_handler
assert final_result == "42"
# ensure the workflow ran each step once
step1_runs = await new_handler.ctx.get("step1_runs")
step2_runs = await new_handler.ctx.get("step2_runs")
assert step1_runs == 1
assert step2_runs == 1
class DummyWorkflowForConcurrentRunsTest(Workflow):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment