From 82fd82e82f6e83843c1dc3f29279ef586b7ba6e2 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi <mpippi@gmail.com> Date: Thu, 27 Feb 2025 20:14:24 +0100 Subject: [PATCH] fix: take step workers into account when running a workflow step-wise (#17942) --- .../llama_index/core/workflow/context.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index 0e964bc38a..1c7603d983 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -3,7 +3,18 @@ import json import uuid import warnings from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + DefaultDict, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, +) from .context_serializers import BaseSerializer, JsonSerializer from .decorators import StepConfig @@ -58,8 +69,9 @@ class Context: # an existing context. self._in_progress: Dict[str, List[Event]] = defaultdict(list) # Keep track of the steps currently running. This is only valid when a - # workflow is running and won't be serialized. - self._currently_running_steps: Set[str] = set() + # workflow is running and won't be serialized. Note that a single step + # might have multiple workers, so we keep a counter. + self._currently_running_steps: DefaultDict[str, int] = defaultdict(int) # Streaming machinery self._streaming_queue: asyncio.Queue = asyncio.Queue() # Global data storage @@ -207,11 +219,13 @@ class Context: async def add_running_step(self, name: str) -> None: async with self.lock: - self._currently_running_steps.add(name) + self._currently_running_steps[name] += 1 async def remove_running_step(self, name: str) -> None: async with self.lock: - self._currently_running_steps.remove(name) + self._currently_running_steps[name] -= 1 + if self._currently_running_steps[name] == 0: + del self._currently_running_steps[name] async def running_steps(self) -> List[str]: async with self.lock: -- GitLab