From 82fd82e82f6e83843c1dc3f29279ef586b7ba6e2 Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi <mpippi@gmail.com>
Date: Thu, 27 Feb 2025 20:14:24 +0100
Subject: [PATCH] fix: take step workers into account when running a workflow
 step-wise (#17942)

---
 .../llama_index/core/workflow/context.py      | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py
index 0e964bc38a..1c7603d983 100644
--- a/llama-index-core/llama_index/core/workflow/context.py
+++ b/llama-index-core/llama_index/core/workflow/context.py
@@ -3,7 +3,18 @@ import json
 import uuid
 import warnings
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    DefaultDict,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+)
 
 from .context_serializers import BaseSerializer, JsonSerializer
 from .decorators import StepConfig
@@ -58,8 +69,9 @@ class Context:
         # an existing context.
         self._in_progress: Dict[str, List[Event]] = defaultdict(list)
         # Keep track of the steps currently running. This is only valid when a
-        # workflow is running and won't be serialized.
-        self._currently_running_steps: Set[str] = set()
+        # workflow is running and won't be serialized. Note that a single step
+        # might have multiple workers, so we keep a counter.
+        self._currently_running_steps: DefaultDict[str, int] = defaultdict(int)
         # Streaming machinery
         self._streaming_queue: asyncio.Queue = asyncio.Queue()
         # Global data storage
@@ -207,11 +219,13 @@ class Context:
 
     async def add_running_step(self, name: str) -> None:
         async with self.lock:
-            self._currently_running_steps.add(name)
+            self._currently_running_steps[name] += 1
 
     async def remove_running_step(self, name: str) -> None:
         async with self.lock:
-            self._currently_running_steps.remove(name)
+            self._currently_running_steps[name] -= 1
+            if self._currently_running_steps[name] == 0:
+                del self._currently_running_steps[name]
 
     async def running_steps(self) -> List[str]:
         async with self.lock:
-- 
GitLab