diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index 0e964bc38af84c6d13d59f2105dd72da4fad0bd6..1c7603d98336c3a233df648fb0883cb785cfc00d 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -3,7 +3,18 @@ import json import uuid import warnings from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + DefaultDict, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, +) from .context_serializers import BaseSerializer, JsonSerializer from .decorators import StepConfig @@ -58,8 +69,9 @@ class Context: # an existing context. self._in_progress: Dict[str, List[Event]] = defaultdict(list) # Keep track of the steps currently running. This is only valid when a - # workflow is running and won't be serialized. - self._currently_running_steps: Set[str] = set() + # workflow is running and won't be serialized. Note that a single step + # might have multiple workers, so we keep a counter. + self._currently_running_steps: DefaultDict[str, int] = defaultdict(int) # Streaming machinery self._streaming_queue: asyncio.Queue = asyncio.Queue() # Global data storage @@ -207,11 +219,13 @@ class Context: async def add_running_step(self, name: str) -> None: async with self.lock: - self._currently_running_steps.add(name) + self._currently_running_steps[name] += 1 async def remove_running_step(self, name: str) -> None: async with self.lock: - self._currently_running_steps.remove(name) + self._currently_running_steps[name] -= 1 + if self._currently_running_steps[name] == 0: + del self._currently_running_steps[name] async def running_steps(self) -> List[str]: async with self.lock: