From 9163067027ea8222e9fe5bffff9a2fac26b57686 Mon Sep 17 00:00:00 2001
From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com>
Date: Sun, 7 Apr 2024 18:55:20 -0400
Subject: [PATCH] Coroutine-safe Spans (#12589)

* add async task stack

* just use asyncio.current_task

* wip

* delay asyncio Lock assignment in dispatcher

* fix async unit tests
---
 .../llama_index/core/async_utils.py           |   6 +
 .../core/instrumentation/dispatcher.py        | 106 +++++++++++++++++-
 .../instrumentation/span_handlers/base.py     |   3 +-
 .../tests/instrumentation/test_dispatcher.py  |  14 ++-
 4 files changed, 121 insertions(+), 8 deletions(-)

diff --git a/llama-index-core/llama_index/core/async_utils.py b/llama-index-core/llama_index/core/async_utils.py
index b07ada61cd..c07afb5412 100644
--- a/llama-index-core/llama_index/core/async_utils.py
+++ b/llama-index-core/llama_index/core/async_utils.py
@@ -3,6 +3,9 @@
 import asyncio
 from itertools import zip_longest
 from typing import Any, Coroutine, Iterable, List, TypeVar
+import llama_index.core.instrumentation as instrument
+
+dispatcher = instrument.get_dispatcher(__name__)
 
 
 def asyncio_module(show_progress: bool = False) -> Any:
@@ -84,6 +87,7 @@ DEFAULT_NUM_WORKERS = 4
 T = TypeVar("T")
 
 
+@dispatcher.span
 async def run_jobs(
     jobs: List[Coroutine[Any, Any, T]],
     show_progress: bool = False,
@@ -101,9 +105,11 @@ async def run_jobs(
         List[Any]:
             List of results.
     """
+    parent_span_id = dispatcher.current_span_id
     asyncio_mod = get_asyncio_module(show_progress=show_progress)
     semaphore = asyncio.Semaphore(workers)
 
+    @dispatcher.async_span_with_parent_id(parent_id=parent_span_id)
     async def worker(job: Coroutine) -> Any:
         async with semaphore:
             return await job
diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
index c4b659aa06..7646c77795 100644
--- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py
+++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
@@ -13,9 +13,13 @@ from llama_index.core.instrumentation.span_handlers import (
 )
 from llama_index.core.instrumentation.events.base import BaseEvent
 from llama_index.core.instrumentation.events.span import SpanDropEvent
+from contextvars import ContextVar
 import wrapt
 
 
+span_ctx = ContextVar("span_ctx", default={})  # asyncio task name -> list (stack) of span ids open in that task
+
+
 class EventDispatcher(Protocol):
     def __call__(self, event: BaseEvent) -> None:
         ...
@@ -50,7 +54,7 @@ class Dispatcher(BaseModel):
     current_span_id: Optional[str] = Field(
         default=None, description="Id of current span."
     )
-    _asyncio_lock: asyncio.Lock = PrivateAttr()
+    _asyncio_lock: Optional[asyncio.Lock] = PrivateAttr()
 
     def __init__(
         self,
@@ -62,7 +66,7 @@ class Dispatcher(BaseModel):
         root_name: str = "root",
         propagate: bool = True,
     ):
-        self._asyncio_lock = asyncio.Lock()
+        self._asyncio_lock = None
         super().__init__(
             name=name,
             event_handlers=event_handlers,
@@ -73,6 +77,12 @@ class Dispatcher(BaseModel):
             propagate=propagate,
         )
 
+    @property
+    def asyncio_lock(self) -> asyncio.Lock:
+        if self._asyncio_lock is None:  # __init__ leaves this as None; the Lock is created on first access
+            self._asyncio_lock = asyncio.Lock()
+        return self._asyncio_lock
+
     @property
     def parent(self) -> "Dispatcher":
         return self.manager.dispatchers[self.parent_name]
@@ -107,6 +117,7 @@ class Dispatcher(BaseModel):
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
+        parent_id: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Send notice to handlers that a span with id_ has started."""
@@ -117,6 +128,7 @@ class Dispatcher(BaseModel):
                     id_=id_,
                     bound_args=bound_args,
                     instance=instance,
+                    parent_id=parent_id,
                     **kwargs,
                 )
             if not c.propagate:
@@ -199,6 +211,67 @@ class Dispatcher(BaseModel):
         finally:
             del dispatch_event
 
+    def async_span_with_parent_id(self, parent_id: str):
+        """Return a decorator that spans an async function under an explicitly given parent span.
+
+        Primary example: llama_index.core.async_utils.run_jobs
+
+        Args:
+            parent_id (str): The span_id of the outer span; forwarded unchanged to span_enter.
+        """
+
+        def outer(func):
+            @wrapt.decorator
+            async def async_wrapper(func, instance, args, kwargs):
+                bound_args = inspect.signature(func).bind(*args, **kwargs)
+                id_ = f"{func.__qualname__}-{uuid.uuid4()}"  # span id: qualified name plus a random UUID
+                async with self.asyncio_lock:
+                    self.current_span_id = id_
+                async with self.root.asyncio_lock:
+                    self.root.current_span_id = id_  # the root dispatcher records the same id
+
+                current_task = asyncio.current_task()
+                current_task_name = current_task.get_name()  # the task *name* is the key into span_ctx
+                span_ctx_dict = span_ctx.get().copy()  # shallow copy: the per-task lists are still shared objects
+                if current_task_name not in span_ctx_dict:
+                    span_ctx_dict[current_task_name] = [id_]
+                else:
+                    span_ctx_dict[current_task_name].append(id_)  # NOTE(review): appends in place to the shared list
+                span_ctx.set(span_ctx_dict)
+
+                self.span_enter(
+                    id_=id_,
+                    bound_args=bound_args,
+                    instance=instance,
+                    parent_id=parent_id,
+                )
+                try:
+                    result = await func(*args, **kwargs)
+                except BaseException as e:  # presumably BaseException so cancellation also drops the span — confirm
+                    self.event(SpanDropEvent(span_id=id_, err_str=str(e)))
+                    self.span_drop(
+                        id_=id_, bound_args=bound_args, instance=instance, err=e
+                    )
+                    raise
+                else:
+                    self.span_exit(
+                        id_=id_, bound_args=bound_args, instance=instance, result=result
+                    )
+                    return result
+                finally:
+                    # clean up: pop the newest id from this task's stack and drop the entry once it is empty
+                    current_task = asyncio.current_task()
+                    current_task_name = current_task.get_name()
+                    span_ctx_dict = span_ctx.get().copy()
+                    span_ctx_dict[current_task_name].pop()  # NOTE(review): assumes the last entry is this span's id_ — confirm
+                    if len(span_ctx_dict[current_task_name]) == 0:
+                        del span_ctx_dict[current_task_name]
+                    span_ctx.set(span_ctx_dict)
+
+            return async_wrapper(func)
+
+        return outer
+
     def span(self, func):
         @wrapt.decorator
         def wrapper(func, instance, args, kwargs):
@@ -223,12 +296,26 @@ class Dispatcher(BaseModel):
         async def async_wrapper(func, instance, args, kwargs):
             bound_args = inspect.signature(func).bind(*args, **kwargs)
             id_ = f"{func.__qualname__}-{uuid.uuid4()}"
-            async with self._asyncio_lock:
+            async with self.asyncio_lock:
                 self.current_span_id = id_
-            async with self.root._asyncio_lock:
+            async with self.root.asyncio_lock:
                 self.root.current_span_id = id_
 
-            self.span_enter(id_=id_, bound_args=bound_args, instance=instance)
+            # get parent_id
+            current_task = asyncio.current_task()
+            current_task_name = current_task.get_name()
+            span_ctx_dict = span_ctx.get().copy()
+            if current_task_name not in span_ctx_dict:
+                parent_id = None
+                span_ctx_dict[current_task_name] = [id_]
+            else:
+                parent_id = span_ctx_dict[current_task_name][-1]
+                span_ctx_dict[current_task_name].append(id_)
+            span_ctx.set(span_ctx_dict)
+
+            self.span_enter(
+                id_=id_, bound_args=bound_args, instance=instance, parent_id=parent_id
+            )
             try:
                 result = await func(*args, **kwargs)
             except BaseException as e:
@@ -240,6 +327,15 @@ class Dispatcher(BaseModel):
                     id_=id_, bound_args=bound_args, instance=instance, result=result
                 )
                 return result
+            finally:
+                # clean up
+                current_task = asyncio.current_task()
+                current_task_name = current_task.get_name()
+                span_ctx_dict = span_ctx.get().copy()
+                span_ctx_dict[current_task_name].pop()
+                if len(span_ctx_dict[current_task_name]) == 0:
+                    del span_ctx_dict[current_task_name]
+                span_ctx.set(span_ctx_dict)
 
         if inspect.iscoroutinefunction(func):
             return async_wrapper(func)
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
index 725cfaa41f..82b5b487c8 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
@@ -34,6 +34,7 @@ class BaseSpanHandler(BaseModel, Generic[T]):
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
+        parent_id: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Logic for entering a span."""
@@ -45,7 +46,7 @@ class BaseSpanHandler(BaseModel, Generic[T]):
                 id_=id_,
                 bound_args=bound_args,
                 instance=instance,
-                parent_span_id=self.current_span_id,
+                parent_span_id=parent_id or self.current_span_id,
             )
             if span:
                 self.open_spans[id_] = span
diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py
index f7269f11b3..a59b917441 100644
--- a/llama-index-core/tests/instrumentation/test_dispatcher.py
+++ b/llama-index-core/tests/instrumentation/test_dispatcher.py
@@ -265,7 +265,12 @@ async def test_dispatcher_async_span_args(mock_uuid, mock_span_enter, mock_span_
     mock_span_enter.assert_called_once()
     args, kwargs = mock_span_enter.call_args
     assert args == ()
-    assert kwargs == {"id_": span_id, "bound_args": bound_args, "instance": None}
+    assert kwargs == {
+        "id_": span_id,
+        "bound_args": bound_args,
+        "instance": None,
+        "parent_id": None,
+    }
 
     # span_exit
     args, kwargs = mock_span_exit.call_args
@@ -299,7 +304,12 @@ async def test_dispatcher_async_span_args_with_instance(
     mock_span_enter.assert_called_once()
     args, kwargs = mock_span_enter.call_args
     assert args == ()
-    assert kwargs == {"id_": span_id, "bound_args": bound_args, "instance": instance}
+    assert kwargs == {
+        "id_": span_id,
+        "bound_args": bound_args,
+        "instance": instance,
+        "parent_id": None,
+    }
 
     # span_exit
     args, kwargs = mock_span_exit.call_args
-- 
GitLab