From 9163067027ea8222e9fe5bffff9a2fac26b57686 Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Sun, 7 Apr 2024 18:55:20 -0400 Subject: [PATCH] Coroutine-safe Spans (#12589) * add async task stack * just use asyncio.current_task * wip * delay asyncio Lock assignment in dispatcher * fix async unit tests --- .../llama_index/core/async_utils.py | 6 + .../core/instrumentation/dispatcher.py | 106 +++++++++++++++++- .../instrumentation/span_handlers/base.py | 3 +- .../tests/instrumentation/test_dispatcher.py | 14 ++- 4 files changed, 121 insertions(+), 8 deletions(-) diff --git a/llama-index-core/llama_index/core/async_utils.py b/llama-index-core/llama_index/core/async_utils.py index b07ada61cd..c07afb5412 100644 --- a/llama-index-core/llama_index/core/async_utils.py +++ b/llama-index-core/llama_index/core/async_utils.py @@ -3,6 +3,9 @@ import asyncio from itertools import zip_longest from typing import Any, Coroutine, Iterable, List, TypeVar +import llama_index.core.instrumentation as instrument + +dispatcher = instrument.get_dispatcher(__name__) def asyncio_module(show_progress: bool = False) -> Any: @@ -84,6 +87,7 @@ DEFAULT_NUM_WORKERS = 4 T = TypeVar("T") +@dispatcher.span async def run_jobs( jobs: List[Coroutine[Any, Any, T]], show_progress: bool = False, @@ -101,9 +105,11 @@ async def run_jobs( List[Any]: List of results. """ + parent_span_id = dispatcher.current_span_id asyncio_mod = get_asyncio_module(show_progress=show_progress) semaphore = asyncio.Semaphore(workers) + @dispatcher.async_span_with_parent_id(parent_id=parent_span_id) async def worker(job: Coroutine) -> Any: async with semaphore: return await job diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index c4b659aa06..7646c77795 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -13,9 +13,13 @@ from llama_index.core.instrumentation.span_handlers import ( ) from llama_index.core.instrumentation.events.base import BaseEvent from llama_index.core.instrumentation.events.span import SpanDropEvent +from contextvars import ContextVar import wrapt +span_ctx = ContextVar("span_ctx", default={}) + + class EventDispatcher(Protocol): def __call__(self, event: BaseEvent) -> None: ... @@ -50,7 +54,7 @@ class Dispatcher(BaseModel): current_span_id: Optional[str] = Field( default=None, description="Id of current span." ) - _asyncio_lock: asyncio.Lock = PrivateAttr() + _asyncio_lock: Optional[asyncio.Lock] = PrivateAttr() def __init__( self, @@ -62,7 +66,7 @@ class Dispatcher(BaseModel): root_name: str = "root", propagate: bool = True, ): - self._asyncio_lock = asyncio.Lock() + self._asyncio_lock = None super().__init__( name=name, event_handlers=event_handlers, @@ -73,6 +77,12 @@ class Dispatcher(BaseModel): propagate=propagate, ) + @property + def asyncio_lock(self) -> asyncio.Lock: + if self._asyncio_lock is None: + self._asyncio_lock = asyncio.Lock() + return self._asyncio_lock + @property def parent(self) -> "Dispatcher": return self.manager.dispatchers[self.parent_name] @@ -107,6 +117,7 @@ class Dispatcher(BaseModel): id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, + parent_id: Optional[str] = None, **kwargs: Any, ) -> None: """Send notice to handlers that a span with id_ has started.""" @@ -117,6 +128,7 @@ class Dispatcher(BaseModel): id_=id_, bound_args=bound_args, instance=instance, + parent_id=parent_id, **kwargs, ) if not c.propagate: @@ -199,6 +211,67 @@ class Dispatcher(BaseModel): finally: del dispatch_event + def async_span_with_parent_id(self, parent_id: str): + """This decorator should be used to span an async function nested in an outer span. + + Primary example: llama_index.core.async_utils.run_jobs + + Args: + parent_id (str): The span_id of the outer span. + """ + + def outer(func): + @wrapt.decorator + async def async_wrapper(func, instance, args, kwargs): + bound_args = inspect.signature(func).bind(*args, **kwargs) + id_ = f"{func.__qualname__}-{uuid.uuid4()}" + async with self.asyncio_lock: + self.current_span_id = id_ + async with self.root.asyncio_lock: + self.root.current_span_id = id_ + + current_task = asyncio.current_task() + current_task_name = current_task.get_name() + span_ctx_dict = span_ctx.get().copy() + if current_task_name not in span_ctx_dict: + span_ctx_dict[current_task_name] = [id_] + else: + span_ctx_dict[current_task_name].append(id_) + span_ctx.set(span_ctx_dict) + + self.span_enter( + id_=id_, + bound_args=bound_args, + instance=instance, + parent_id=parent_id, + ) + try: + result = await func(*args, **kwargs) + except BaseException as e: + self.event(SpanDropEvent(span_id=id_, err_str=str(e))) + self.span_drop( + id_=id_, bound_args=bound_args, instance=instance, err=e + ) + raise + else: + self.span_exit( + id_=id_, bound_args=bound_args, instance=instance, result=result + ) + return result + finally: + # clean up + current_task = asyncio.current_task() + current_task_name = current_task.get_name() + span_ctx_dict = span_ctx.get().copy() + span_ctx_dict[current_task_name].pop() + if len(span_ctx_dict[current_task_name]) == 0: + del span_ctx_dict[current_task_name] + span_ctx.set(span_ctx_dict) + + return async_wrapper(func) + + return outer + def span(self, func): @wrapt.decorator def wrapper(func, instance, args, kwargs): @@ -223,12 +296,26 @@ class Dispatcher(BaseModel): async def async_wrapper(func, instance, args, kwargs): bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" - async with self._asyncio_lock: + async with self.asyncio_lock: self.current_span_id = id_ - async with self.root._asyncio_lock: + async with self.root.asyncio_lock: self.root.current_span_id = id_ - self.span_enter(id_=id_, bound_args=bound_args, instance=instance) + # get parent_id + current_task = asyncio.current_task() + current_task_name = current_task.get_name() + span_ctx_dict = span_ctx.get().copy() + if current_task_name not in span_ctx_dict: + parent_id = None + span_ctx_dict[current_task_name] = [id_] + else: + parent_id = span_ctx_dict[current_task_name][-1] + span_ctx_dict[current_task_name].append(id_) + span_ctx.set(span_ctx_dict) + + self.span_enter( + id_=id_, bound_args=bound_args, instance=instance, parent_id=parent_id + ) try: result = await func(*args, **kwargs) except BaseException as e: @@ -240,6 +327,15 @@ class Dispatcher(BaseModel): id_=id_, bound_args=bound_args, instance=instance, result=result ) return result + finally: + # clean up + current_task = asyncio.current_task() + current_task_name = current_task.get_name() + span_ctx_dict = span_ctx.get().copy() + span_ctx_dict[current_task_name].pop() + if len(span_ctx_dict[current_task_name]) == 0: + del span_ctx_dict[current_task_name] + span_ctx.set(span_ctx_dict) if inspect.iscoroutinefunction(func): return async_wrapper(func) diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py index 725cfaa41f..82b5b487c8 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py @@ -34,6 +34,7 @@ class BaseSpanHandler(BaseModel, Generic[T]): id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, + parent_id: Optional[str] = None, **kwargs: Any, ) -> None: """Logic for entering a span.""" @@ -45,7 +46,7 @@ class BaseSpanHandler(BaseModel, Generic[T]): id_=id_, bound_args=bound_args, instance=instance, - parent_span_id=self.current_span_id, + parent_span_id=parent_id or self.current_span_id, ) if span: self.open_spans[id_] = span diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py index f7269f11b3..a59b917441 100644 --- a/llama-index-core/tests/instrumentation/test_dispatcher.py +++ b/llama-index-core/tests/instrumentation/test_dispatcher.py @@ -265,7 +265,12 @@ async def test_dispatcher_async_span_args(mock_uuid, mock_span_enter, mock_span_ mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () - assert kwargs == {"id_": span_id, "bound_args": bound_args, "instance": None} + assert kwargs == { + "id_": span_id, + "bound_args": bound_args, + "instance": None, + "parent_id": None, + } # span_exit args, kwargs = mock_span_exit.call_args @@ -299,7 +304,12 @@ async def test_dispatcher_async_span_args_with_instance( mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () - assert kwargs == {"id_": span_id, "bound_args": bound_args, "instance": instance} + assert kwargs == { + "id_": span_id, + "bound_args": bound_args, + "instance": instance, + "parent_id": None, + } # span_exit args, kwargs = mock_span_exit.call_args -- GitLab