diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 7646c77795778fb1810afb89e47e9ace6ac1a43f..5a2d9e21881865d5769c81f8b0954d989258c86b 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -1,8 +1,9 @@ from typing import Any, List, Optional, Dict, Protocol from functools import partial -from contextlib import contextmanager +from collections import defaultdict import asyncio import inspect +import threading import uuid from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr from llama_index.core.instrumentation.events import BaseEvent @@ -17,7 +18,11 @@ from contextvars import ContextVar import wrapt -span_ctx = ContextVar("span_ctx", default={}) +# ContextVars for managing active spans +span_ctx_var = ContextVar( + "span_ctx_var", default=defaultdict(dict) +) # per thread >> async-task +DEFAULT_SYNC_KEY = "sync_tasks" class EventDispatcher(Protocol): @@ -25,14 +30,21 @@ class EventDispatcher(Protocol): ... -class EventContext(BaseModel): - span_id: str = Field(default="") - +class Dispatcher(BaseModel): + """Dispatcher class. -event_context = EventContext() + Responsible for dispatching BaseEvent (and its subclasses) as well as + sending signals to enter/exit/drop a BaseSpan. It does so by sending + event and span signals to its attached BaseEventHandler as well as + BaseSpanHandler. + Concurrency: + - Dispatcher is async-task and thread safe in the sense that + spans of async coros will maintain their hierarchy or trace-trees and + spans which emanate from various threads will also maintain their + hierarchy. 
+ """ -class Dispatcher(BaseModel): name: str = Field(default_factory=str, description="Name of dispatcher") event_handlers: List[BaseEventHandler] = Field( default=[], description="List of attached handlers" @@ -51,10 +63,12 @@ class Dispatcher(BaseModel): default=True, description="Whether to propagate the event to parent dispatchers and their handlers", ) - current_span_id: Optional[str] = Field( - default=None, description="Id of current span." + current_span_ids: Optional[Dict[Any, str]] = Field( + default_factory=dict, + description="Id of current enclosing span. Used for creating `dispatch_event` partials.", ) _asyncio_lock: Optional[asyncio.Lock] = PrivateAttr() + _lock: Optional[threading.Lock] = PrivateAttr() def __init__( self, @@ -67,6 +81,7 @@ class Dispatcher(BaseModel): propagate: bool = True, ): self._asyncio_lock = None + self._lock = None super().__init__( name=name, event_handlers=event_handlers, @@ -77,12 +92,29 @@ class Dispatcher(BaseModel): propagate=propagate, ) + @property + def current_span_id(self) -> Optional[str]: + current_thread = threading.get_ident() + if current_thread in self.current_span_ids: + return self.current_span_ids[current_thread] + return None + + def set_current_span_id(self, value: str): + current_thread = threading.get_ident() + self.current_span_ids[current_thread] = value + @property def asyncio_lock(self) -> asyncio.Lock: if self._asyncio_lock is None: self._asyncio_lock = asyncio.Lock() return self._asyncio_lock + @property + def lock(self) -> threading.Lock: + if self._lock is None: + self._lock = threading.Lock() + return self._lock + @property def parent(self) -> "Dispatcher": return self.manager.dispatchers[self.parent_name] @@ -191,95 +223,53 @@ class Dispatcher(BaseModel): functions only. Otherwise, the span_id should not be trusted, as the span decorator sets the span_id. 
""" - span_id = self.current_span_id + with self.lock: + span_id = self.current_span_id dispatch_event: EventDispatcher = partial(self.event, span_id=span_id) return dispatch_event - @contextmanager - def dispatch_event(self): - """Context manager for firing events within a span session. - - This context manager should be used with @dispatcher.span decorated - functions only. Otherwise, the span_id should not be trusted, as the - span decorator sets the span_id. - """ - span_id = self.current_span_id - dispatch_event: EventDispatcher = partial(self.event, span_id=span_id) - - try: - yield dispatch_event - finally: - del dispatch_event - - def async_span_with_parent_id(self, parent_id: str): - """This decorator should be used to span an async function nested in an outer span. - - Primary example: llama_index.core.async_utils.run_jobs - - Args: - parent_id (str): The span_id of the outer span. - """ - - def outer(func): - @wrapt.decorator - async def async_wrapper(func, instance, args, kwargs): - bound_args = inspect.signature(func).bind(*args, **kwargs) - id_ = f"{func.__qualname__}-{uuid.uuid4()}" - async with self.asyncio_lock: - self.current_span_id = id_ - async with self.root.asyncio_lock: - self.root.current_span_id = id_ + def _get_parent_update_span_ctx_var(self, id_: str, current_task_name: str): + """Helper method to get parent id from the appropriate async contextvar.""" + current_thread = threading.get_ident() + thread_span_ctx = span_ctx_var.get().copy() + span_ctx = thread_span_ctx[current_thread] + if current_task_name not in span_ctx: + parent_id = None + span_ctx[current_task_name] = [id_] + else: + parent_id = span_ctx[current_task_name][-1] + span_ctx[current_task_name].append(id_) + span_ctx_var.set(thread_span_ctx) - current_task = asyncio.current_task() - current_task_name = current_task.get_name() - span_ctx_dict = span_ctx.get().copy() - if current_task_name not in span_ctx_dict: - span_ctx_dict[current_task_name] = [id_] - else: - 
span_ctx_dict[current_task_name].append(id_) - span_ctx.set(span_ctx_dict) + return parent_id - self.span_enter( - id_=id_, - bound_args=bound_args, - instance=instance, - parent_id=parent_id, - ) - try: - result = await func(*args, **kwargs) - except BaseException as e: - self.event(SpanDropEvent(span_id=id_, err_str=str(e))) - self.span_drop( - id_=id_, bound_args=bound_args, instance=instance, err=e - ) - raise - else: - self.span_exit( - id_=id_, bound_args=bound_args, instance=instance, result=result - ) - return result - finally: - # clean up - current_task = asyncio.current_task() - current_task_name = current_task.get_name() - span_ctx_dict = span_ctx.get().copy() - span_ctx_dict[current_task_name].pop() - if len(span_ctx_dict[current_task_name]) == 0: - del span_ctx_dict[current_task_name] - span_ctx.set(span_ctx_dict) + def _pop_span_from_ctx_var(self, current_task_name: str) -> None: + """Helper method to pop completed/dropped span from async contextvar.""" + current_thread = threading.get_ident() + thread_span_ctx = span_ctx_var.get().copy() + span_ctx = thread_span_ctx[current_thread] - return async_wrapper(func) - - return outer + span_ctx[current_task_name].pop() + if len(span_ctx[current_task_name]) == 0: + del span_ctx[current_task_name] + span_ctx_var.set(thread_span_ctx) def span(self, func): @wrapt.decorator def wrapper(func, instance, args, kwargs): bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" - self.current_span_id = id_ - self.root.current_span_id = id_ - self.span_enter(id_=id_, bound_args=bound_args, instance=instance) + with self.lock: + self.set_current_span_id(id_) + with self.root.lock: + self.root.set_current_span_id(id_) + + # get parent_id (thread-safe) + parent_id = self._get_parent_update_span_ctx_var(id_, DEFAULT_SYNC_KEY) + + self.span_enter( + id_=id_, bound_args=bound_args, instance=instance, parent_id=parent_id + ) try: result = func(*args, **kwargs) except 
BaseException as e: @@ -291,27 +281,24 @@ class Dispatcher(BaseModel): id_=id_, bound_args=bound_args, instance=instance, result=result ) return result + finally: + # clean up + self._pop_span_from_ctx_var(DEFAULT_SYNC_KEY) @wrapt.decorator async def async_wrapper(func, instance, args, kwargs): bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" async with self.asyncio_lock: - self.current_span_id = id_ + self.set_current_span_id(id_) async with self.root.asyncio_lock: - self.root.current_span_id = id_ + self.root.set_current_span_id(id_) - # get parent_id + # get parent_id (thread and async-task safe) + # spans are managed in this hierarchy: thread > async task > async coros current_task = asyncio.current_task() current_task_name = current_task.get_name() - span_ctx_dict = span_ctx.get().copy() - if current_task_name not in span_ctx_dict: - parent_id = None - span_ctx_dict[current_task_name] = [id_] - else: - parent_id = span_ctx_dict[current_task_name][-1] - span_ctx_dict[current_task_name].append(id_) - span_ctx.set(span_ctx_dict) + parent_id = self._get_parent_update_span_ctx_var(id_, current_task_name) self.span_enter( id_=id_, bound_args=bound_args, instance=instance, parent_id=parent_id @@ -329,19 +316,64 @@ class Dispatcher(BaseModel): return result finally: # clean up - current_task = asyncio.current_task() - current_task_name = current_task.get_name() - span_ctx_dict = span_ctx.get().copy() - span_ctx_dict[current_task_name].pop() - if len(span_ctx_dict[current_task_name]) == 0: - del span_ctx_dict[current_task_name] - span_ctx.set(span_ctx_dict) + self._pop_span_from_ctx_var(current_task_name) if inspect.iscoroutinefunction(func): return async_wrapper(func) else: return wrapper(func) + def async_span_with_parent_id(self, parent_id: str): + """This decorator should be used to span an async function nested in an outer span. 
+ + Primary example: llama_index.core.async_utils.run_jobs + + Args: + parent_id (str): The span_id of the outer span. + """ + + def outer(func): + @wrapt.decorator + async def async_wrapper(func, instance, args, kwargs): + bound_args = inspect.signature(func).bind(*args, **kwargs) + id_ = f"{func.__qualname__}-{uuid.uuid4()}" + async with self.asyncio_lock: + self.set_current_span_id(id_) + async with self.root.asyncio_lock: + self.root.set_current_span_id(id_) + + # don't need parent_id but need to update span ctx var + current_task = asyncio.current_task() + current_task_name = current_task.get_name() + _ = self._get_parent_update_span_ctx_var(id_, current_task_name) + + self.span_enter( + id_=id_, + bound_args=bound_args, + instance=instance, + parent_id=parent_id, + ) + try: + result = await func(*args, **kwargs) + except BaseException as e: + self.event(SpanDropEvent(span_id=id_, err_str=str(e))) + self.span_drop( + id_=id_, bound_args=bound_args, instance=instance, err=e + ) + raise + else: + self.span_exit( + id_=id_, bound_args=bound_args, instance=instance, result=result + ) + return result + finally: + # clean up + self._pop_span_from_ctx_var(current_task_name) + + return async_wrapper(func) + + return outer + @property def log_name(self) -> str: """Name to be used in logging.""" diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py index 82b5b487c84181c12ec52c75f442e2c192900556..e5a6710af7cdf8110468b1e9b8f944076f84f0b6 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py @@ -1,8 +1,9 @@ import inspect +import threading from abc import abstractmethod from typing import Any, Dict, List, Generic, Optional, TypeVar -from llama_index.core.bridge.pydantic import BaseModel, Field +from llama_index.core.bridge.pydantic import BaseModel, Field, 
PrivateAttr from llama_index.core.instrumentation.span.base import BaseSpan T = TypeVar("T", bound=BaseSpan) @@ -18,17 +19,50 @@ class BaseSpanHandler(BaseModel, Generic[T]): dropped_spans: List[T] = Field( default_factory=list, description="List of completed spans." ) - current_span_id: Optional[str] = Field( - default=None, description="Id of current span." + current_span_ids: Dict[Any, Optional[str]] = Field( + default={}, description="Id of current spans in a given thread." ) + _lock: Optional[threading.Lock] = PrivateAttr() class Config: arbitrary_types_allowed = True + def __init__( + self, + open_spans: Dict[str, T] = {}, + completed_spans: List[T] = [], + dropped_spans: List[T] = [], + current_span_ids: Dict[Any, str] = {}, + ): + self._lock = None + super().__init__( + open_spans=open_spans, + completed_spans=completed_spans, + dropped_spans=dropped_spans, + current_span_ids=current_span_ids, + ) + def class_name(cls) -> str: """Class name.""" return "BaseSpanHandler" + @property + def lock(self) -> threading.Lock: + if self._lock is None: + self._lock = threading.Lock() + return self._lock + + @property + def current_span_id(self) -> Optional[str]: + current_thread = threading.get_ident() + if current_thread in self.current_span_ids: + return self.current_span_ids[current_thread] + return None + + def set_current_span_id(self, value: str) -> None: + current_thread = threading.get_ident() + self.current_span_ids[current_thread] = value + def span_enter( self, id_: str, @@ -49,8 +83,9 @@ class BaseSpanHandler(BaseModel, Generic[T]): parent_span_id=parent_id or self.current_span_id, ) if span: - self.open_spans[id_] = span - self.current_span_id = id_ + with self.lock: + self.open_spans[id_] = span + self.set_current_span_id(id_) def span_exit( self, @@ -65,11 +100,13 @@ class BaseSpanHandler(BaseModel, Generic[T]): id_=id_, bound_args=bound_args, instance=instance, result=result ) if span: - if self.current_span_id == id_: - self.current_span_id = 
self.open_spans[id_].parent_id - del self.open_spans[id_] + with self.lock: + if self.current_span_id == id_: + self.set_current_span_id(self.open_spans[id_].parent_id) + del self.open_spans[id_] if not self.open_spans: # empty so flush - self.current_span_id = None + with self.lock: + self.set_current_span_id(None) def span_drop( self, @@ -84,9 +121,10 @@ class BaseSpanHandler(BaseModel, Generic[T]): id_=id_, bound_args=bound_args, instance=instance, err=err ) if span: - if self.current_span_id == id_: - self.current_span_id = self.open_spans[id_].parent_id - del self.open_spans[id_] + with self.lock: + if self.current_span_id == id_: + self.set_current_span_id(self.open_spans[id_].parent_id) + del self.open_spans[id_] @abstractmethod def new_span( @@ -97,7 +135,11 @@ class BaseSpanHandler(BaseModel, Generic[T]): parent_span_id: Optional[str] = None, **kwargs: Any, ) -> Optional[T]: - """Create a span.""" + """Create a span. + + Subclasses of BaseSpanHandler should create the respective span type T + and return it. Only NullSpanHandler should return a None here. + """ ... @abstractmethod @@ -109,7 +151,12 @@ class BaseSpanHandler(BaseModel, Generic[T]): result: Optional[Any] = None, **kwargs: Any, ) -> Optional[T]: - """Logic for preparing to exit a span.""" + """Logic for preparing to exit a span. + + Subclasses of BaseSpanHandler should return back the specific span T + that is to be exited. If None is returned, then the span won't actually + be exited. + """ ... @abstractmethod @@ -121,5 +168,10 @@ class BaseSpanHandler(BaseModel, Generic[T]): err: Optional[BaseException] = None, **kwargs: Any, ) -> Optional[T]: - """Logic for preparing to drop a span.""" + """Logic for preparing to drop a span. + + Subclasses of BaseSpanHandler should return back the specific span T + that is to be dropped. If None is returned, then the span won't actually + be dropped. + """ ... 
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py index f0ae9b0fab43e60da6bd821c9e26b52e3a56e836..e870b4cced2d7eedefe86791a557a35644468300 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py @@ -3,6 +3,7 @@ from typing import Any, cast, List, Optional, TYPE_CHECKING from llama_index.core.instrumentation.span.simple import SimpleSpan from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler from datetime import datetime +from functools import reduce import warnings if TYPE_CHECKING: @@ -40,7 +41,8 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): span = cast(SimpleSpan, span) span.end_time = datetime.now() span.duration = (span.end_time - span.start_time).total_seconds() - self.completed_spans += [span] + with self.lock: + self.completed_spans += [span] return span def prepare_to_drop_span( @@ -53,55 +55,79 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): ) -> SimpleSpan: """Logic for droppping a span.""" if id_ in self.open_spans: - span = self.open_spans[id_] - span.metadata = {"error": str(err)} - self.dropped_spans += [span] + with self.lock: + span = self.open_spans[id_] + span.metadata = {"error": str(err)} + self.dropped_spans += [span] return span return None + def _get_parents(self) -> List[SimpleSpan]: + """Helper method to get all parent/root spans.""" + all_spans = self.completed_spans + self.dropped_spans + return [s for s in all_spans if s.parent_id is None] + + def _build_tree_by_parent( + self, parent: SimpleSpan, acc: List[SimpleSpan], spans: List[SimpleSpan] + ) -> List[SimpleSpan]: + """Builds the tree by parent root.""" + if not spans: + return acc + + children = [s for s in spans if s.parent_id == parent.id_] + if not children: + return acc + updated_spans = [s for s in spans if s not 
in children] + + children_trees = [ + self._build_tree_by_parent( + parent=c, acc=[c], spans=[s for s in updated_spans if c != s] + ) + for c in children + ] + + return acc + reduce(lambda x, y: x + y, children_trees) + def _get_trace_trees(self) -> List["Tree"]: """Method for getting trace trees.""" try: from treelib import Tree - from treelib.exceptions import NodeIDAbsentError except ImportError as e: raise ImportError( "`treelib` package is missing. Please install it by using " "`pip install treelib`." ) - sorted_spans = sorted( - self.completed_spans + self.dropped_spans, key=lambda x: x.start_time - ) + + all_spans = self.completed_spans + self.dropped_spans + for s in all_spans: + if s.parent_id is None: + continue + if not any(ns.id_ == s.parent_id for ns in all_spans): + warnings.warn("Parent with id {span.parent_id} missing from spans") + s.parent_id += "-MISSING" + all_spans.append(SimpleSpan(id_=s.parent_id, parent_id=None)) + + parents = self._get_parents() + span_groups = [] + for p in parents: + this_span_group = self._build_tree_by_parent( + parent=p, acc=[p], spans=[s for s in all_spans if s != p] + ) + sorted_span_group = sorted(this_span_group, key=lambda x: x.start_time) + span_groups.append(sorted_span_group) trees = [] tree = Tree() - for span in sorted_spans: - if span.parent_id is None: - # complete old tree unless its empty (i.e., start of loop) - if tree.all_nodes(): - trees.append(tree) - # start new tree - tree = Tree() - - try: - tree.create_node( - tag=f"{span.id_} ({span.duration})", - identifier=span.id_, - parent=span.parent_id, - data=span.start_time, - ) - except NodeIDAbsentError: - warnings.warn("Parent with id {span.parent_id} missing from spans") - # create new tree and fake parent node - trees.append(tree) - tree = Tree() - tree.create_node( - tag=f"{span.parent_id} (MISSING)", - identifier=span.parent_id, - parent=None, - data=span.start_time, - ) + for grp in span_groups: + for span in grp: + if span.parent_id is None: + # 
complete old tree unless it's empty (i.e., start of loop) + if tree.all_nodes(): + trees.append(tree) + # start new tree + tree = Tree() + tree.create_node( tag=f"{span.id_} ({span.duration})", identifier=span.id_, diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py index a59b917441b2b20f132a75e5b7fc467c000e1dc1..c07b7208fdf4e7be048539f498ec65df79b23df6 100644 --- a/llama-index-core/tests/instrumentation/test_dispatcher.py +++ b/llama-index-core/tests/instrumentation/test_dispatcher.py @@ -126,7 +126,12 @@ def test_dispatcher_span_args(mock_uuid, mock_span_enter, mock_span_exit): mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () - assert kwargs == {"id_": span_id, "bound_args": bound_args, "instance": None} + assert kwargs == { + "id_": span_id, + "bound_args": bound_args, + "instance": None, + "parent_id": None, + } # span_exit args, kwargs = mock_span_exit.call_args @@ -157,7 +162,12 @@ def test_dispatcher_span_args_with_instance(mock_uuid, mock_span_enter, mock_spa mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () - assert kwargs == {"id_": span_id, "bound_args": bound_args, "instance": instance} + assert kwargs == { + "id_": span_id, + "bound_args": bound_args, + "instance": instance, + "parent_id": None, + } # span_exit args, kwargs = mock_span_exit.call_args