Skip to content
Snippets Groups Projects
Unverified Commit dd593093 authored by Logan's avatar Logan Committed by GitHub
Browse files

safer workflow cancel + fix restored context bug (#17938)

parent 1d4c9679
No related branches found
No related tags found
No related merge requests found
...@@ -496,6 +496,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): ...@@ -496,6 +496,7 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
state_prompt: Optional[Union[str, BasePromptTemplate]] = None, state_prompt: Optional[Union[str, BasePromptTemplate]] = None,
initial_state: Optional[dict] = None, initial_state: Optional[dict] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
verbose: bool = False,
) -> "AgentWorkflow": ) -> "AgentWorkflow":
"""Initializes an AgentWorkflow from a list of tools or functions. """Initializes an AgentWorkflow from a list of tools or functions.
...@@ -528,4 +529,5 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta): ...@@ -528,4 +529,5 @@ class AgentWorkflow(Workflow, PromptMixin, metaclass=AgentWorkflowMeta):
state_prompt=state_prompt, state_prompt=state_prompt,
initial_state=initial_state, initial_state=initial_state,
timeout=timeout, timeout=timeout,
verbose=verbose,
) )
...@@ -369,8 +369,11 @@ class Workflow(metaclass=WorkflowMeta): ...@@ -369,8 +369,11 @@ class Workflow(metaclass=WorkflowMeta):
# add dedicated cancel task # add dedicated cancel task
async def _cancel_workflow_task() -> None: async def _cancel_workflow_task() -> None:
await ctx._cancel_flag.wait() try:
raise WorkflowCancelledByUser await ctx._cancel_flag.wait()
raise WorkflowCancelledByUser
except asyncio.CancelledError:
return
ctx._tasks.add( ctx._tasks.add(
asyncio.create_task( asyncio.create_task(
...@@ -458,11 +461,6 @@ class Workflow(metaclass=WorkflowMeta): ...@@ -458,11 +461,6 @@ class Workflow(metaclass=WorkflowMeta):
# the context is now running # the context is now running
ctx.is_running = True ctx.is_running = True
else:
# resend in-progress events if already running
for name, evs in ctx._in_progress.items():
for ev in evs:
ctx.send_event(ev, step=name)
done, unfinished = await asyncio.wait( done, unfinished = await asyncio.wait(
ctx._tasks, ctx._tasks,
...@@ -485,7 +483,14 @@ class Workflow(metaclass=WorkflowMeta): ...@@ -485,7 +483,14 @@ class Workflow(metaclass=WorkflowMeta):
t.cancel() t.cancel()
# wait for cancelled tasks to cleanup # wait for cancelled tasks to cleanup
await asyncio.gather(*unfinished, return_exceptions=True) # prevents any tasks from being stuck
try:
await asyncio.wait_for(
asyncio.gather(*unfinished, return_exceptions=True),
timeout=0.5,
)
except asyncio.TimeoutError:
logger.warning("Some tasks did not clean up within timeout")
# the context is no longer running # the context is no longer running
ctx.is_running = False ctx.is_running = False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment