From 95be930f5f715ba2268483b4c31b55c4c78e3e6d Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Tue, 2 Apr 2024 00:20:48 -0400 Subject: [PATCH] Add `span_id` attribute to Events (instrumentation) (#12417) * add span_id to Event * remove raise err in NullHandler * wip * modify root dispatcher to track the enclosing span for events * remove *args as we have bound_args now * add LLMChatInProgressEvent * add LLMStructuredPredict events * store span_id before await executions * add SpanDropEvent with err_str payload * add event to _achat; flush current_span_id when open_spans is empty * llm callbacks use root span_id * add unit tests * remove print statements * provide context manager returning a dispatch event partial with correct span id * move to context manager usage * fix invocation of cm * define and use get_dispatch_event method * remove aim tests --- .../llama_index/core/agent/runner/base.py | 23 +- .../core/base/base_query_engine.py | 12 +- .../llama_index/core/base/base_retriever.py | 30 ++- .../llama_index/core/base/embeddings/base.py | 90 +++++-- .../llama_index/core/chat_engine/types.py | 26 +- .../core/instrumentation/__init__.py | 5 +- .../core/instrumentation/dispatcher.py | 92 ++++++- .../core/instrumentation/events/base.py | 1 + .../core/instrumentation/events/llm.py | 24 ++ .../core/instrumentation/events/span.py | 10 + .../core/instrumentation/span/simple.py | 3 +- .../instrumentation/span_handlers/base.py | 16 +- .../instrumentation/span_handlers/null.py | 8 +- .../instrumentation/span_handlers/simple.py | 18 +- .../llama_index/core/llms/callbacks.py | 37 ++- llama-index-core/llama_index/core/llms/llm.py | 29 ++- .../core/response_synthesizers/base.py | 64 ++++- .../core/response_synthesizers/refine.py | 11 +- .../tests/instrumentation/test_dispatcher.py | 242 ++++++++++++++++-- .../llama-index-callbacks-aim/tests/BUILD | 1 - .../tests/__init__.py | 0 .../tests/test_aim_callback.py | 7 - 22 files changed, 617 insertions(+), 132 deletions(-) create mode 100644 llama-index-core/llama_index/core/instrumentation/events/span.py delete mode 100644 llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD delete mode 100644 llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/__init__.py delete mode 100644 llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py diff --git a/llama-index-core/llama_index/core/agent/runner/base.py b/llama-index-core/llama_index/core/agent/runner/base.py index c5759ffed4..2f17a796d6 100644 --- a/llama-index-core/llama_index/core/agent/runner/base.py +++ b/llama-index-core/llama_index/core/agent/runner/base.py @@ -365,7 +365,9 @@ class AgentRunner(BaseAgentRunner): **kwargs: Any, ) -> TaskStepOutput: """Execute step.""" - dispatcher.event(AgentRunStepStartEvent()) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(AgentRunStepStartEvent()) task = self.state.get_task(task_id) step_queue = self.state.get_step_queue(task_id) step = step or step_queue.popleft() @@ -392,7 +394,7 @@ class AgentRunner(BaseAgentRunner): completed_steps = self.state.get_completed_steps(task_id) completed_steps.append(cur_step_output) - dispatcher.event(AgentRunStepEndEvent()) + dispatch_event(AgentRunStepEndEvent()) return cur_step_output @dispatcher.span @@ -405,6 +407,9 @@ class AgentRunner(BaseAgentRunner): **kwargs: Any, ) -> TaskStepOutput: """Execute step.""" + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(AgentRunStepStartEvent()) task = 
self.state.get_task(task_id) step_queue = self.state.get_step_queue(task_id) step = step or step_queue.popleft() @@ -430,6 +435,7 @@ class AgentRunner(BaseAgentRunner): completed_steps = self.state.get_completed_steps(task_id) completed_steps.append(cur_step_output) + dispatch_event(AgentRunStepEndEvent()) return cur_step_output @dispatcher.span @@ -528,12 +534,14 @@ class AgentRunner(BaseAgentRunner): mode: ChatResponseMode = ChatResponseMode.WAIT, ) -> AGENT_CHAT_RESPONSE_TYPE: """Chat with step executor.""" + dispatch_event = dispatcher.get_dispatch_event() + if chat_history is not None: self.memory.set(chat_history) task = self.create_task(message) result_output = None - dispatcher.event(AgentChatWithStepStartEvent()) + dispatch_event(AgentChatWithStepStartEvent()) while True: # pass step queue in as argument, assume step executor is stateless cur_step_output = self._run_step( @@ -551,7 +559,7 @@ class AgentRunner(BaseAgentRunner): task.task_id, result_output, ) - dispatcher.event(AgentChatWithStepEndEvent()) + dispatch_event(AgentChatWithStepEndEvent()) return result @dispatcher.span @@ -563,11 +571,14 @@ class AgentRunner(BaseAgentRunner): mode: ChatResponseMode = ChatResponseMode.WAIT, ) -> AGENT_CHAT_RESPONSE_TYPE: """Chat with step executor.""" + dispatch_event = dispatcher.get_dispatch_event() + if chat_history is not None: self.memory.set(chat_history) task = self.create_task(message) result_output = None + dispatch_event(AgentChatWithStepStartEvent()) while True: # pass step queue in as argument, assume step executor is stateless cur_step_output = await self._arun_step( @@ -581,10 +592,12 @@ class AgentRunner(BaseAgentRunner): # ensure tool_choice does not cause endless loops tool_choice = "auto" - return self.finalize_response( + result = self.finalize_response( task.task_id, result_output, ) + dispatch_event(AgentChatWithStepEndEvent()) + return result @dispatcher.span @trace_method("chat") diff --git a/llama-index-core/llama_index/core/base/base_query_engine.py b/llama-index-core/llama_index/core/base/base_query_engine.py index b81fd17396..42f6b33e8c 100644 --- a/llama-index-core/llama_index/core/base/base_query_engine.py +++ b/llama-index-core/llama_index/core/base/base_query_engine.py @@ -44,22 +44,26 @@ class BaseQueryEngine(ChainableMixin, PromptMixin): @dispatcher.span def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE: - dispatcher.event(QueryStartEvent()) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(QueryStartEvent()) with self.callback_manager.as_trace("query"): if isinstance(str_or_query_bundle, str): str_or_query_bundle = QueryBundle(str_or_query_bundle) query_result = self._query(str_or_query_bundle) - dispatcher.event(QueryEndEvent()) + dispatch_event(QueryEndEvent()) return query_result @dispatcher.span async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE: - dispatcher.event(QueryStartEvent()) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(QueryStartEvent()) with self.callback_manager.as_trace("query"): if isinstance(str_or_query_bundle, str): str_or_query_bundle = QueryBundle(str_or_query_bundle) query_result = await self._aquery(str_or_query_bundle) - dispatcher.event(QueryEndEvent()) + dispatch_event(QueryEndEvent()) return query_result def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: diff --git a/llama-index-core/llama_index/core/base/base_retriever.py b/llama-index-core/llama_index/core/base/base_retriever.py index 6dc7649de6..f87872b0ff 100644 --- 
a/llama-index-core/llama_index/core/base/base_retriever.py +++ b/llama-index-core/llama_index/core/base/base_retriever.py @@ -224,8 +224,14 @@ class BaseRetriever(ChainableMixin, PromptMixin): a QueryBundle object. """ + dispatch_event = dispatcher.get_dispatch_event() + self._check_callback_manager() - dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle)) + dispatch_event( + RetrievalStartEvent( + str_or_query_bundle=str_or_query_bundle, + ) + ) if isinstance(str_or_query_bundle, str): query_bundle = QueryBundle(str_or_query_bundle) else: @@ -240,15 +246,24 @@ class BaseRetriever(ChainableMixin, PromptMixin): retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) - dispatcher.event( - RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes) + dispatch_event( + RetrievalEndEvent( + str_or_query_bundle=str_or_query_bundle, + nodes=nodes, + ) ) return nodes @dispatcher.span async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: self._check_callback_manager() - dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle)) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + RetrievalStartEvent( + str_or_query_bundle=str_or_query_bundle, + ) + ) if isinstance(str_or_query_bundle, str): query_bundle = QueryBundle(str_or_query_bundle) else: @@ -265,8 +280,11 @@ class BaseRetriever(ChainableMixin, PromptMixin): retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) - dispatcher.event( - RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes) + dispatch_event( + RetrievalEndEvent( + str_or_query_bundle=str_or_query_bundle, + nodes=nodes, + ) ) return nodes diff --git a/llama-index-core/llama_index/core/base/embeddings/base.py b/llama-index-core/llama_index/core/base/embeddings/base.py index 5924aca1a3..3f5c0cf849 100644 --- a/llama-index-core/llama_index/core/base/embeddings/base.py +++ b/llama-index-core/llama_index/core/base/embeddings/base.py @@ -114,7 +114,13 @@ class BaseEmbedding(TransformComponent): other examples of predefined instructions can be found in embeddings/huggingface_utils.py. 
""" - dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict())) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + EmbeddingStartEvent( + model_dict=self.to_dict(), + ) + ) with self.callback_manager.event( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} ) as event: @@ -126,15 +132,24 @@ class BaseEmbedding(TransformComponent): EventPayload.EMBEDDINGS: [query_embedding], }, ) - dispatcher.event( - EmbeddingEndEvent(chunks=[query], embeddings=[query_embedding]) + dispatch_event( + EmbeddingEndEvent( + chunks=[query], + embeddings=[query_embedding], + ) ) return query_embedding @dispatcher.span async def aget_query_embedding(self, query: str) -> Embedding: """Get query embedding.""" - dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict())) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + EmbeddingStartEvent( + model_dict=self.to_dict(), + ) + ) with self.callback_manager.event( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} ) as event: @@ -146,8 +161,11 @@ class BaseEmbedding(TransformComponent): EventPayload.EMBEDDINGS: [query_embedding], }, ) - dispatcher.event( - EmbeddingEndEvent(chunks=[query], embeddings=[query_embedding]) + dispatch_event( + EmbeddingEndEvent( + chunks=[query], + embeddings=[query_embedding], + ) ) return query_embedding @@ -220,7 +238,13 @@ class BaseEmbedding(TransformComponent): document for retrieval: ". If you're curious, other examples of predefined instructions can be found in embeddings/huggingface_utils.py. """ - dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict())) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + EmbeddingStartEvent( + model_dict=self.to_dict(), + ) + ) with self.callback_manager.event( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} ) as event: @@ -232,13 +256,24 @@ class BaseEmbedding(TransformComponent): EventPayload.EMBEDDINGS: [text_embedding], } ) - dispatcher.event(EmbeddingEndEvent(chunks=[text], embeddings=[text_embedding])) + dispatch_event( + EmbeddingEndEvent( + chunks=[text], + embeddings=[text_embedding], + ) + ) return text_embedding @dispatcher.span async def aget_text_embedding(self, text: str) -> Embedding: """Async get text embedding.""" - dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict())) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + EmbeddingStartEvent( + model_dict=self.to_dict(), + ) + ) with self.callback_manager.event( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} ) as event: @@ -250,7 +285,12 @@ class BaseEmbedding(TransformComponent): EventPayload.EMBEDDINGS: [text_embedding], } ) - dispatcher.event(EmbeddingEndEvent(chunks=[text], embeddings=[text_embedding])) + dispatch_event( + EmbeddingEndEvent( + chunks=[text], + embeddings=[text_embedding], + ) + ) return text_embedding @dispatcher.span @@ -261,6 +301,8 @@ class BaseEmbedding(TransformComponent): **kwargs: Any, ) -> List[Embedding]: """Get a list of text embeddings, with batching.""" + dispatch_event = dispatcher.get_dispatch_event() + cur_batch: List[str] = [] result_embeddings: List[Embedding] = [] @@ -272,7 +314,11 @@ class BaseEmbedding(TransformComponent): cur_batch.append(text) if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size: # flush - dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict())) + dispatch_event( + EmbeddingStartEvent( + model_dict=self.to_dict(), + ) + ) with 
self.callback_manager.event( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}, @@ -285,8 +331,11 @@ class BaseEmbedding(TransformComponent): EventPayload.EMBEDDINGS: embeddings, }, ) - dispatcher.event( - EmbeddingEndEvent(chunks=cur_batch, embeddings=embeddings) + dispatch_event( + EmbeddingEndEvent( + chunks=cur_batch, + embeddings=embeddings, + ) ) cur_batch = [] @@ -297,6 +346,8 @@ class BaseEmbedding(TransformComponent): self, texts: List[str], show_progress: bool = False ) -> List[Embedding]: """Asynchronously get a list of text embeddings, with batching.""" + dispatch_event = dispatcher.get_dispatch_event() + cur_batch: List[str] = [] callback_payloads: List[Tuple[str, List[str]]] = [] result_embeddings: List[Embedding] = [] @@ -305,7 +356,11 @@ class BaseEmbedding(TransformComponent): cur_batch.append(text) if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size: # flush - dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict())) + dispatch_event( + EmbeddingStartEvent( + model_dict=self.to_dict(), + ) + ) event_id = self.callback_manager.on_event_start( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}, @@ -337,8 +392,11 @@ class BaseEmbedding(TransformComponent): for (event_id, text_batch), embeddings in zip( callback_payloads, nested_embeddings ): - dispatcher.event( - EmbeddingEndEvent(chunks=text_batch, embeddings=embeddings) + dispatch_event( + EmbeddingEndEvent( + chunks=text_batch, + embeddings=embeddings, + ) ) self.callback_manager.on_event_end( CBEventType.EMBEDDING, diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py index 9291565269..72753cb92e 100644 --- a/llama-index-core/llama_index/core/chat_engine/types.py +++ b/llama-index-core/llama_index/core/chat_engine/types.py @@ -127,15 +127,20 @@ class StreamingAgentChatResponse: raise ValueError( "chat_stream is None. Cannot write to history without chat_stream." 
) + dispatch_event = dispatcher.get_dispatch_event() # try/except to prevent hanging on error - dispatcher.event(StreamChatStartEvent()) + dispatch_event(StreamChatStartEvent()) try: final_text = "" for chat in self.chat_stream: self._is_function = is_function(chat.message) if chat.delta: - dispatcher.event(StreamChatDeltaReceivedEvent(delta=chat.delta)) + dispatch_event( + StreamChatDeltaReceivedEvent( + delta=chat.delta, + ) + ) self.put_in_queue(chat.delta) final_text += chat.delta or "" if self._is_function is not None: # if loop has gone through iteration @@ -144,14 +149,14 @@ class StreamingAgentChatResponse: chat.message.content = final_text.strip() # final message memory.put(chat.message) except Exception as e: - dispatcher.event(StreamChatErrorEvent()) + dispatch_event(StreamChatErrorEvent()) if not raise_error: logger.warning( f"Encountered exception writing response to history: {e}" ) else: raise - dispatcher.event(StreamChatEndEvent()) + dispatch_event(StreamChatEndEvent()) self._is_done = True @@ -167,6 +172,7 @@ class StreamingAgentChatResponse: on_stream_end_fn: Optional[callable] = None, ) -> None: self._ensure_async_setup() + dispatch_event = dispatcher.get_dispatch_event() if self.achat_stream is None: raise ValueError( @@ -175,13 +181,17 @@ class StreamingAgentChatResponse: ) # try/except to prevent hanging on error - dispatcher.event(StreamChatStartEvent()) + dispatch_event(StreamChatStartEvent()) try: final_text = "" async for chat in self.achat_stream: self._is_function = is_function(chat.message) if chat.delta: - dispatcher.event(StreamChatDeltaReceivedEvent(delta=chat.delta)) + dispatch_event( + StreamChatDeltaReceivedEvent( + delta=chat.delta, + ) + ) self.aput_in_queue(chat.delta) final_text += chat.delta or "" self._new_item_event.set() @@ -193,9 +203,9 @@ class StreamingAgentChatResponse: chat.message.content = final_text.strip() # final message memory.put(chat.message) except Exception as e: - dispatcher.event(StreamChatErrorEvent()) + dispatch_event(StreamChatErrorEvent()) logger.warning(f"Encountered exception writing response to history: {e}") - dispatcher.event(StreamChatEndEvent()) + dispatch_event(StreamChatEndEvent()) self._is_done = True # These act as is_done events for any consumers waiting diff --git a/llama-index-core/llama_index/core/instrumentation/__init__.py b/llama-index-core/llama_index/core/instrumentation/__init__.py index b7564a02f7..fd41c2ec17 100644 --- a/llama-index-core/llama_index/core/instrumentation/__init__.py +++ b/llama-index-core/llama_index/core/instrumentation/__init__.py @@ -24,7 +24,10 @@ def get_dispatcher(name: str = "root") -> Dispatcher: parent_name = "root" new_dispatcher = Dispatcher( - name=name, root=root_dispatcher, parent_name=parent_name, manager=root_manager + name=name, + root_name=root_dispatcher.name, + parent_name=parent_name, + manager=root_manager, ) root_manager.add_dispatcher(new_dispatcher) return new_dispatcher diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index 727d7cba82..c4b659aa06 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -1,16 +1,33 @@ -from typing import Any, List, Optional, Dict +from typing import Any, List, Optional, Dict, Protocol +from functools import partial +from contextlib import contextmanager +import asyncio import inspect import uuid -from llama_index.core.bridge.pydantic import 
BaseModel, Field +from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr +from llama_index.core.instrumentation.events import BaseEvent from llama_index.core.instrumentation.event_handlers import BaseEventHandler from llama_index.core.instrumentation.span_handlers import ( BaseSpanHandler, NullSpanHandler, ) from llama_index.core.instrumentation.events.base import BaseEvent +from llama_index.core.instrumentation.events.span import SpanDropEvent import wrapt +class EventDispatcher(Protocol): + def __call__(self, event: BaseEvent) -> None: + ... + + +class EventContext(BaseModel): + span_id: str = Field(default="") + + +event_context = EventContext() + + class Dispatcher(BaseModel): name: str = Field(default_factory=str, description="Name of dispatcher") event_handlers: List[BaseEventHandler] = Field( @@ -30,6 +47,31 @@ class Dispatcher(BaseModel): default=True, description="Whether to propagate the event to parent dispatchers and their handlers", ) + current_span_id: Optional[str] = Field( + default=None, description="Id of current span." + ) + _asyncio_lock: asyncio.Lock = PrivateAttr() + + def __init__( + self, + name: str = "", + event_handlers: List[BaseEventHandler] = [], + span_handlers: List[BaseSpanHandler] = [], + parent_name: str = "", + manager: Optional["Manager"] = None, + root_name: str = "root", + propagate: bool = True, + ): + self._asyncio_lock = asyncio.Lock() + super().__init__( + name=name, + event_handlers=event_handlers, + span_handlers=span_handlers, + parent_name=parent_name, + manager=manager, + root_name=root_name, + propagate=propagate, + ) @property def parent(self) -> "Dispatcher": @@ -47,9 +89,11 @@ class Dispatcher(BaseModel): """Add handler to set of handlers.""" self.span_handlers += [handler] - def event(self, event: BaseEvent, **kwargs) -> None: + def event(self, event: BaseEvent, span_id: Optional[str] = None, **kwargs) -> None: """Dispatch event to all registered handlers.""" c = self + if span_id: + event.span_id = span_id while c: for h in c.event_handlers: h.handle(event, **kwargs) @@ -60,7 +104,6 @@ class Dispatcher(BaseModel): def span_enter( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -71,7 +114,6 @@ class Dispatcher(BaseModel): while c: for h in c.span_handlers: h.span_enter( - *args, id_=id_, bound_args=bound_args, instance=instance, @@ -84,7 +126,6 @@ class Dispatcher(BaseModel): def span_drop( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -96,7 +137,6 @@ class Dispatcher(BaseModel): while c: for h in c.span_handlers: h.span_drop( - *args, id_=id_, bound_args=bound_args, instance=instance, @@ -110,7 +150,6 @@ class Dispatcher(BaseModel): def span_exit( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -122,7 +161,6 @@ class Dispatcher(BaseModel): while c: for h in c.span_handlers: h.span_exit( - *args, id_=id_, bound_args=bound_args, instance=instance, @@ -134,15 +172,45 @@ class Dispatcher(BaseModel): else: c = c.parent + def get_dispatch_event(self) -> EventDispatcher: + """Get dispatch_event for firing events within the context of a span. + + This method should be used with @dispatcher.span decorated + functions only. Otherwise, the span_id should not be trusted, as the + span decorator sets the span_id. 
+ """ + span_id = self.current_span_id + dispatch_event: EventDispatcher = partial(self.event, span_id=span_id) + return dispatch_event + + @contextmanager + def dispatch_event(self): + """Context manager for firing events within a span session. + + This context manager should be used with @dispatcher.span decorated + functions only. Otherwise, the span_id should not be trusted, as the + span decorator sets the span_id. + """ + span_id = self.current_span_id + dispatch_event: EventDispatcher = partial(self.event, span_id=span_id) + + try: + yield dispatch_event + finally: + del dispatch_event + def span(self, func): @wrapt.decorator def wrapper(func, instance, args, kwargs): bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" + self.current_span_id = id_ + self.root.current_span_id = id_ self.span_enter(id_=id_, bound_args=bound_args, instance=instance) try: result = func(*args, **kwargs) except BaseException as e: + self.event(SpanDropEvent(span_id=id_, err_str=str(e))) self.span_drop(id_=id_, bound_args=bound_args, instance=instance, err=e) raise else: @@ -155,10 +223,16 @@ class Dispatcher(BaseModel): async def async_wrapper(func, instance, args, kwargs): bound_args = inspect.signature(func).bind(*args, **kwargs) id_ = f"{func.__qualname__}-{uuid.uuid4()}" + async with self._asyncio_lock: + self.current_span_id = id_ + async with self.root._asyncio_lock: + self.root.current_span_id = id_ + self.span_enter(id_=id_, bound_args=bound_args, instance=instance) try: result = await func(*args, **kwargs) except BaseException as e: + self.event(SpanDropEvent(span_id=id_, err_str=str(e))) self.span_drop(id_=id_, bound_args=bound_args, instance=instance, err=e) raise else: diff --git a/llama-index-core/llama_index/core/instrumentation/events/base.py b/llama-index-core/llama_index/core/instrumentation/events/base.py index 061aea2605..28f3d1dba7 100644 --- a/llama-index-core/llama_index/core/instrumentation/events/base.py +++ b/llama-index-core/llama_index/core/instrumentation/events/base.py @@ -7,6 +7,7 @@ from datetime import datetime class BaseEvent(BaseModel): timestamp: datetime = Field(default_factory=lambda: datetime.now()) id_: str = Field(default_factory=lambda: uuid4()) + span_id: str = Field(default_factory=str) @classmethod def class_name(cls): diff --git a/llama-index-core/llama_index/core/instrumentation/events/llm.py b/llama-index-core/llama_index/core/instrumentation/events/llm.py index b61dde1b84..28a81204a5 100644 --- a/llama-index-core/llama_index/core/instrumentation/events/llm.py +++ b/llama-index-core/llama_index/core/instrumentation/events/llm.py @@ -22,6 +22,20 @@ class LLMPredictEndEvent(BaseEvent): return "LLMPredictEndEvent" +class LLMStructuredPredictStartEvent(BaseEvent): + @classmethod + def class_name(cls): + """Class name.""" + return "LLMStructuredPredictStartEvent" + + +class LLMStructuredPredictEndEvent(BaseEvent): + @classmethod + def class_name(cls): + """Class name.""" + return "LLMStructuredPredictEndEvent" + + class LLMCompletionStartEvent(BaseEvent): prompt: str additional_kwargs: dict @@ -54,6 +68,16 @@ class LLMChatStartEvent(BaseEvent): return "LLMChatStartEvent" +class LLMChatInProgressEvent(BaseEvent): + messages: List[ChatMessage] + response: ChatResponse + + @classmethod + def class_name(cls): + """Class name.""" + return "LLMChatInProgressEvent" + + class LLMChatEndEvent(BaseEvent): messages: List[ChatMessage] response: ChatResponse diff --git 
a/llama-index-core/llama_index/core/instrumentation/events/span.py b/llama-index-core/llama_index/core/instrumentation/events/span.py new file mode 100644 index 0000000000..113932e5b0 --- /dev/null +++ b/llama-index-core/llama_index/core/instrumentation/events/span.py @@ -0,0 +1,10 @@ +from llama_index.core.instrumentation.events.base import BaseEvent + + +class SpanDropEvent(BaseEvent): + err_str: str + + @classmethod + def class_name(cls): + """Class name.""" + return "SpanDropEvent" diff --git a/llama-index-core/llama_index/core/instrumentation/span/simple.py b/llama-index-core/llama_index/core/instrumentation/span/simple.py index ceef9338e1..7905b743f2 100644 --- a/llama-index-core/llama_index/core/instrumentation/span/simple.py +++ b/llama-index-core/llama_index/core/instrumentation/span/simple.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional from llama_index.core.bridge.pydantic import Field from llama_index.core.instrumentation.span.base import BaseSpan from datetime import datetime @@ -10,3 +10,4 @@ class SimpleSpan(BaseSpan): start_time: datetime = Field(default_factory=lambda: datetime.now()) end_time: Optional[datetime] = Field(default=None) duration: float = Field(default=float, description="Duration of span in seconds.") + metadata: Optional[Dict] = Field(default=None) diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py index 78beb7c295..725cfaa41f 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py @@ -1,6 +1,6 @@ import inspect from abc import abstractmethod -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Dict, List, Generic, Optional, TypeVar from llama_index.core.bridge.pydantic import BaseModel, Field from llama_index.core.instrumentation.span.base import BaseSpan @@ -12,6 +12,12 @@ class BaseSpanHandler(BaseModel, Generic[T]): open_spans: Dict[str, T] = Field( default_factory=dict, description="Dictionary of open spans." ) + completed_spans: List[T] = Field( + default_factory=list, description="List of completed spans." + ) + dropped_spans: List[T] = Field( + default_factory=list, description="List of dropped spans." + ) current_span_id: Optional[str] = Field( default=None, description="Id of current span." 
) @@ -25,7 +31,6 @@ class BaseSpanHandler(BaseModel, Generic[T]): def span_enter( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -48,7 +53,6 @@ class BaseSpanHandler(BaseModel, Generic[T]): def span_exit( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -63,10 +67,11 @@ class BaseSpanHandler(BaseModel, Generic[T]): if self.current_span_id == id_: self.current_span_id = self.open_spans[id_].parent_id del self.open_spans[id_] + if not self.open_spans: # empty so flush + self.current_span_id = None def span_drop( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -85,7 +90,6 @@ class BaseSpanHandler(BaseModel, Generic[T]): @abstractmethod def new_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -98,7 +102,6 @@ class BaseSpanHandler(BaseModel, Generic[T]): @abstractmethod def prepare_to_exit_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -111,7 +114,6 @@ class BaseSpanHandler(BaseModel, Generic[T]): @abstractmethod def prepare_to_drop_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py index a4c6000ebc..02788db0e0 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py @@ -12,7 +12,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): def span_enter( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -23,7 +22,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): def span_exit( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -35,7 +33,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): def new_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -47,7 +44,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): def prepare_to_exit_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -59,7 +55,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): def prepare_to_drop_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -67,5 +62,4 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): **kwargs: Any ) -> None: """Logic for dropping a span.""" - if err: - raise err + return diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py index 18f9bddc4f..f0ae9b0fab 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py @@ -1,6 +1,5 @@ import inspect from typing import Any, cast, List, Optional, TYPE_CHECKING -from llama_index.core.bridge.pydantic import Field from llama_index.core.instrumentation.span.simple import SimpleSpan from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler from datetime import datetime @@ -13,17 +12,12 @@ if TYPE_CHECKING: class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): """Span Handler that 
manages SimpleSpan's.""" - completed_spans: List[SimpleSpan] = Field( - default_factory=list, description="List of completed spans." - ) - def class_name(cls) -> str: """Class name.""" return "SimpleSpanHandler" def new_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -35,7 +29,6 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): def prepare_to_exit_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -52,7 +45,6 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): def prepare_to_drop_span( self, - *args: Any, id_: str, bound_args: inspect.BoundArguments, instance: Optional[Any] = None, @@ -61,7 +53,11 @@ ) -> SimpleSpan: """Logic for dropping a span.""" if id_ in self.open_spans: - return self.open_spans[id_] + span = self.open_spans[id_] + span.metadata = {"error": str(err)} + self.dropped_spans += [span] + return span + + return None def _get_trace_trees(self) -> List["Tree"]: @@ -74,7 +70,9 @@ "`treelib` package is missing. Please install it by using " "`pip install treelib`." ) - sorted_spans = sorted(self.completed_spans, key=lambda x: x.start_time) + sorted_spans = sorted( + self.completed_spans + self.dropped_spans, key=lambda x: x.start_time + ) trees = [] tree = Tree() diff --git a/llama-index-core/llama_index/core/llms/callbacks.py b/llama-index-core/llama_index/core/llms/callbacks.py index e3602007d1..1b9075ee6e 100644 --- a/llama-index-core/llama_index/core/llms/callbacks.py +++ b/llama-index-core/llama_index/core/llms/callbacks.py @@ -27,6 +27,7 @@ from llama_index.core.instrumentation.events.llm import ( LLMCompletionStartEvent, LLMChatEndEvent, LLMChatStartEvent, + LLMChatInProgressEvent, ) dispatcher = get_dispatcher(__name__) @@ -49,11 +50,13 @@ def llm_chat_callback() -> Callable: _self: Any, messages: Sequence[ChatMessage], **kwargs: Any ) -> Any: with wrapper_logic(_self) as callback_manager: + span_id = dispatcher.root.current_span_id or "" dispatcher.event( LLMChatStartEvent( model_dict=_self.to_dict(), messages=messages, additional_kwargs=kwargs, + span_id=span_id, ) ) event_id = callback_manager.on_event_start( @@ -72,9 +75,10 @@ def llm_chat_callback() -> Callable: last_response = None async for x in f_return_val: dispatcher.event( - LLMChatEndEvent( + LLMChatInProgressEvent( messages=messages, response=x, + span_id=span_id, ) ) yield cast(ChatResponse, x) @@ -88,6 +92,13 @@ def llm_chat_callback() -> Callable: }, event_id=event_id, ) + dispatcher.event( + LLMChatEndEvent( + messages=messages, + response=x, + span_id=span_id, + ) + ) return wrapped_gen() else: @@ -103,6 +114,7 @@ def llm_chat_callback() -> Callable: LLMChatEndEvent( messages=messages, response=f_return_val, + span_id=span_id, ) ) @@ -112,11 +124,13 @@ def llm_chat_callback() -> Callable: _self: Any, messages: Sequence[ChatMessage], **kwargs: Any ) -> Any: with wrapper_logic(_self) as callback_manager: + span_id = dispatcher.root.current_span_id or "" dispatcher.event( LLMChatStartEvent( model_dict=_self.to_dict(), messages=messages, additional_kwargs=kwargs, + span_id=span_id, ) ) event_id = callback_manager.on_event_start( @@ -135,9 +149,10 @@ def llm_chat_callback() -> Callable: last_response = None for x in f_return_val: dispatcher.event( - LLMChatEndEvent( + LLMChatInProgressEvent( messages=messages, response=x, + span_id=span_id, ) ) yield cast(ChatResponse, x) @@ 
-151,6 +166,13 @@ def llm_chat_callback() -> Callable: }, event_id=event_id, ) + dispatcher.event( + LLMChatEndEvent( + messages=messages, + response=x, + span_id=span_id, + ) + ) return wrapped_gen() else: @@ -166,6 +188,7 @@ def llm_chat_callback() -> Callable: LLMChatEndEvent( messages=messages, response=f_return_val, + span_id=span_id, ) ) @@ -213,11 +236,13 @@ def llm_completion_callback() -> Callable: _self: Any, *args: Any, **kwargs: Any ) -> Any: with wrapper_logic(_self) as callback_manager: + span_id = dispatcher.root.current_span_id or "" dispatcher.event( LLMCompletionStartEvent( model_dict=_self.to_dict(), prompt=str(args[0]), additional_kwargs=kwargs, + span_id=span_id, ) ) event_id = callback_manager.on_event_start( @@ -240,6 +265,7 @@ def llm_completion_callback() -> Callable: LLMCompletionEndEvent( prompt=str(args[0]), response=x, + span_id=span_id, ) ) yield cast(CompletionResponse, x) @@ -268,6 +294,7 @@ def llm_completion_callback() -> Callable: LLMCompletionEndEvent( prompt=str(args[0]), response=f_return_val, + span_id=span_id, ) ) @@ -275,11 +302,13 @@ def llm_completion_callback() -> Callable: def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any: with wrapper_logic(_self) as callback_manager: + span_id = dispatcher.root.current_span_id or "" dispatcher.event( LLMCompletionStartEvent( model_dict=_self.to_dict(), prompt=str(args[0]), additional_kwargs=kwargs, + span_id=span_id, ) ) event_id = callback_manager.on_event_start( @@ -299,8 +328,7 @@ def llm_completion_callback() -> Callable: for x in f_return_val: dispatcher.event( LLMCompletionEndEvent( - prompt=str(args[0]), - response=x, + prompt=str(args[0]), response=x, span_id=span_id ) ) yield cast(CompletionResponse, x) @@ -329,6 +357,7 @@ def llm_completion_callback() -> Callable: LLMCompletionEndEvent( prompt=str(args[0]), response=f_return_val, + span_id=span_id, ) ) diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py index 678d9dfcea..b8349ea892 100644 --- a/llama-index-core/llama_index/core/llms/llm.py +++ b/llama-index-core/llama_index/core/llms/llm.py @@ -51,6 +51,9 @@ from llama_index.core.types import ( ) from llama_index.core.instrumentation.events.llm import ( LLMPredictEndEvent, + LLMPredictStartEvent, + LLMStructuredPredictEndEvent, + LLMStructuredPredictStartEvent, ) import llama_index.core.instrumentation as instrument @@ -323,6 +326,9 @@ class LLM(BaseLLM): """ from llama_index.core.program.utils import get_program_for_llm + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(LLMStructuredPredictStartEvent()) program = get_program_for_llm( output_cls, prompt, @@ -330,7 +336,9 @@ class LLM(BaseLLM): pydantic_program_mode=self.pydantic_program_mode, ) - return program(**prompt_args) + result = program(**prompt_args) + dispatch_event(LLMStructuredPredictEndEvent()) + return result @dispatcher.span async def astructured_predict( @@ -369,6 +377,10 @@ class LLM(BaseLLM): """ from llama_index.core.program.utils import get_program_for_llm + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(LLMStructuredPredictStartEvent()) + program = get_program_for_llm( output_cls, prompt, @@ -376,7 +388,9 @@ class LLM(BaseLLM): pydantic_program_mode=self.pydantic_program_mode, ) - return await program.acall(**prompt_args) + result = await program.acall(**prompt_args) + dispatch_event(LLMStructuredPredictEndEvent()) + return result # -- Prompt Chaining -- @@ -406,6 +420,9 @@ class LLM(BaseLLM): print(output) ``` """ + 
dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(LLMPredictStartEvent()) self._log_template_data(prompt, **prompt_args) if self.metadata.is_chat_model: @@ -416,8 +433,7 @@ class LLM(BaseLLM): formatted_prompt = self._get_prompt(prompt, **prompt_args) response = self.complete(formatted_prompt, formatted=True) output = response.text - - dispatcher.event(LLMPredictEndEvent()) + dispatch_event(LLMPredictEndEvent()) return self._parse_output(output) @dispatcher.span @@ -489,6 +505,9 @@ class LLM(BaseLLM): print(output) ``` """ + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(LLMPredictStartEvent()) self._log_template_data(prompt, **prompt_args) if self.metadata.is_chat_model: @@ -500,7 +519,7 @@ class LLM(BaseLLM): response = await self.acomplete(formatted_prompt, formatted=True) output = response.text - dispatcher.event(LLMPredictEndEvent()) + dispatch_event(LLMPredictEndEvent()) return self._parse_output(output) @dispatcher.span diff --git a/llama-index-core/llama_index/core/response_synthesizers/base.py b/llama-index-core/llama_index/core/response_synthesizers/base.py index ebedd93aa6..e29f8bcb51 100644 --- a/llama-index-core/llama_index/core/response_synthesizers/base.py +++ b/llama-index-core/llama_index/core/response_synthesizers/base.py @@ -201,21 +201,33 @@ class BaseSynthesizer(ChainableMixin, PromptMixin): additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, **response_kwargs: Any, ) -> RESPONSE_TYPE: - dispatcher.event(SynthesizeStartEvent(query=query)) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + SynthesizeStartEvent( + query=query, + ) + ) if len(nodes) == 0: if self._streaming: empty_response = StreamingResponse( response_gen=empty_response_generator() ) - dispatcher.event( - SynthesizeEndEvent(query=query, response=empty_response) + dispatch_event( + SynthesizeEndEvent( + query=query, + response=empty_response, + ) ) return empty_response else: empty_response = Response("Empty Response") - dispatcher.event( - SynthesizeEndEvent(query=query, response=empty_response) + dispatch_event( + SynthesizeEndEvent( + query=query, + response=empty_response, + ) ) return empty_response @@ -223,7 +235,8 @@ class BaseSynthesizer(ChainableMixin, PromptMixin): query = QueryBundle(query_str=query) with self._callback_manager.event( - CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str} + CBEventType.SYNTHESIZE, + payload={EventPayload.QUERY_STR: query.query_str}, ) as event: response_str = self.get_response( query_str=query.query_str, @@ -240,7 +253,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin): event.on_end(payload={EventPayload.RESPONSE: response}) - dispatcher.event(SynthesizeEndEvent(query=query, response=response)) + dispatch_event( + SynthesizeEndEvent( + query=query, + response=response, + ) + ) return response @dispatcher.span @@ -251,20 +269,32 @@ class BaseSynthesizer(ChainableMixin, PromptMixin): additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, **response_kwargs: Any, ) -> RESPONSE_TYPE: - dispatcher.event(SynthesizeStartEvent(query=query)) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event( + SynthesizeStartEvent( + query=query, + ) + ) if len(nodes) == 0: if self._streaming: empty_response = AsyncStreamingResponse( response_gen=empty_response_agenerator() ) - dispatcher.event( - SynthesizeEndEvent(query=query, response=empty_response) + dispatch_event( + SynthesizeEndEvent( + query=query, + response=empty_response, + ) ) return 
empty_response else: empty_response = Response("Empty Response") - dispatcher.event( - SynthesizeEndEvent(query=query, response=empty_response) + dispatch_event( + SynthesizeEndEvent( + query=query, + response=empty_response, + ) ) return empty_response @@ -272,7 +302,8 @@ class BaseSynthesizer(ChainableMixin, PromptMixin): query = QueryBundle(query_str=query) with self._callback_manager.event( - CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str} + CBEventType.SYNTHESIZE, + payload={EventPayload.QUERY_STR: query.query_str}, ) as event: response_str = await self.aget_response( query_str=query.query_str, @@ -289,7 +320,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin): event.on_end(payload={EventPayload.RESPONSE: response}) - dispatcher.event(SynthesizeEndEvent(query=query, response=response)) + dispatch_event( + SynthesizeEndEvent( + query=query, + response=response, + ) + ) return response def _as_query_component(self, **kwargs: Any) -> QueryComponent: diff --git a/llama-index-core/llama_index/core/response_synthesizers/refine.py b/llama-index-core/llama_index/core/response_synthesizers/refine.py index 61bb4564bc..7b7f753266 100644 --- a/llama-index-core/llama_index/core/response_synthesizers/refine.py +++ b/llama-index-core/llama_index/core/response_synthesizers/refine.py @@ -172,7 +172,9 @@ class Refine(BaseSynthesizer): **response_kwargs: Any, ) -> RESPONSE_TEXT_TYPE: """Give response over chunks.""" - dispatcher.event(GetResponseStartEvent()) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(GetResponseStartEvent()) response: Optional[RESPONSE_TEXT_TYPE] = None for text_chunk in text_chunks: if prev_response is None: @@ -194,7 +196,7 @@ class Refine(BaseSynthesizer): response = response or "Empty Response" else: response = cast(Generator, response) - dispatcher.event(GetResponseEndEvent()) + dispatch_event(GetResponseEndEvent()) return response def _default_program_factory(self, prompt: PromptTemplate) -> BasePydanticProgram: @@ -350,7 +352,9 @@ class Refine(BaseSynthesizer): prev_response: Optional[RESPONSE_TEXT_TYPE] = None, **response_kwargs: Any, ) -> RESPONSE_TEXT_TYPE: - dispatcher.event(GetResponseStartEvent()) + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(GetResponseStartEvent()) response: Optional[RESPONSE_TEXT_TYPE] = None for text_chunk in text_chunks: if prev_response is None: @@ -373,6 +377,7 @@ class Refine(BaseSynthesizer): response = response or "Empty Response" else: response = cast(AsyncGenerator, response) + dispatch_event(GetResponseEndEvent()) return response async def _arefine_response_single( diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py index c256719b4b..f7269f11b3 100644 --- a/llama-index-core/tests/instrumentation/test_dispatcher.py +++ b/llama-index-core/tests/instrumentation/test_dispatcher.py @@ -1,9 +1,13 @@ +import asyncio import inspect from asyncio import CancelledError +from collections import Counter import pytest import llama_index.core.instrumentation as instrument from llama_index.core.instrumentation.dispatcher import Dispatcher +from llama_index.core.instrumentation.events import BaseEvent +from llama_index.core.instrumentation.event_handlers import BaseEventHandler from unittest.mock import patch, MagicMock dispatcher = instrument.get_dispatcher("test") @@ -12,43 +16,98 @@ value_error = ValueError("value error") cancelled_error = CancelledError("cancelled error") +class 
_TestStartEvent(BaseEvent): + @classmethod + def class_name(cls): + return "_TestStartEvent" + + +class _TestEndEvent(BaseEvent): + @classmethod + def class_name(cls): + return "_TestEndEvent" + + +class _TestEventHandler(BaseEventHandler): + events: list = [] + + @classmethod + def class_name(cls): + return "_TestEventHandler" + + def handle(self, e: BaseEvent): + self.events.append(e) + + @dispatcher.span -def func(*args, a, b=3, **kwargs): +def func(a, b=3, **kwargs): return a + b @dispatcher.span -async def async_func(*args, a, b=3, **kwargs): +async def async_func(a, b=3, **kwargs): return a + b @dispatcher.span -def func_exc(*args, a, b=3, c=4, **kwargs): +def func_exc(a, b=3, c=4, **kwargs): raise value_error @dispatcher.span -async def async_func_exc(*args, a, b=3, c=4, **kwargs): +async def async_func_exc(a, b=3, c=4, **kwargs): raise cancelled_error +@dispatcher.span +def func_with_event(a, b=3, **kwargs): + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(_TestStartEvent()) + + +@dispatcher.span +async def async_func_with_event(a, b=3, **kwargs): + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(_TestStartEvent()) + await asyncio.sleep(0.1) + dispatch_event(_TestEndEvent()) + + class _TestObject: @dispatcher.span - def func(self, *args, a, b=3, **kwargs): + def func(self, a, b=3, **kwargs): return a + b @dispatcher.span - async def async_func(self, *args, a, b=3, **kwargs): + async def async_func(self, a, b=3, **kwargs): return a + b @dispatcher.span - def func_exc(self, *args, a, b=3, c=4, **kwargs): + def func_exc(self, a, b=3, c=4, **kwargs): raise value_error @dispatcher.span - async def async_func_exc(self, *args, a, b=3, c=4, **kwargs): + async def async_func_exc(self, a, b=3, c=4, **kwargs): raise cancelled_error + @dispatcher.span + def func_with_event(self, a, b=3, **kwargs): + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(_TestStartEvent()) + + @dispatcher.span + async def async_func_with_event(self, a, b=3, **kwargs): + dispatch_event = dispatcher.get_dispatch_event() + + dispatch_event(_TestStartEvent()) + await asyncio.sleep(0.1) + await self.async_func(1) # this should create a new span_id + # that is fine because we have dispatch_event + dispatch_event(_TestEndEvent()) + @patch.object(Dispatcher, "span_exit") @patch.object(Dispatcher, "span_enter") @patch("llama_index.core.instrumentation.dispatcher.uuid") def test_dispatcher_span_args(mock_uuid, mock_span_enter, mock_span_exit): mock_uuid.uuid4.return_value = "mock" # act - result = func(1, 2, a=3, c=5) + result = func(3, c=5) # assert # span_enter span_id = f"{func.__qualname__}-mock" - bound_args = inspect.signature(func).bind(1, 2, a=3, c=5) + bound_args = inspect.signature(func).bind(3, c=5) mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () @@ -89,12 +148,12 @@ def test_dispatcher_span_args_with_instance(mock_uuid, mock_span_enter, mock_spa # act instance = _TestObject() - result = instance.func(1, 2, a=3, c=5) + result = instance.func(3, c=5) # assert # span_enter span_id = f"{instance.func.__qualname__}-mock" - bound_args = inspect.signature(instance.func).bind(1, 2, a=3, c=5) + bound_args = inspect.signature(instance.func).bind(3, c=5) mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () @@ -126,7 +185,7 @@ def test_dispatcher_span_drop_args( with pytest.raises(ValueError): # act - _ = func_exc(7, a=3, b=5, c=2, d=5) + _ = func_exc(3, b=5, c=2, d=5) # assert # span_enter @@ -135,7 +194,7 @@ def 
test_dispatcher_span_drop_args( # span_drop mock_span_drop.assert_called_once() span_id = f"{func_exc.__qualname__}-mock" - bound_args = inspect.signature(func_exc).bind(7, a=3, b=5, c=2, d=5) + bound_args = inspect.signature(func_exc).bind(3, b=5, c=2, d=5) args, kwargs = mock_span_drop.call_args assert args == () assert kwargs == { @@ -165,7 +224,7 @@ def test_dispatcher_span_drop_args( with pytest.raises(ValueError): # act instance = _TestObject() - _ = instance.func_exc(7, a=3, b=5, c=2, d=5) + _ = instance.func_exc(a=3, b=5, c=2, d=5) # assert # span_enter @@ -174,7 +233,7 @@ def test_dispatcher_span_drop_args( # span_drop mock_span_drop.assert_called_once() span_id = f"{instance.func_exc.__qualname__}-mock" - bound_args = inspect.signature(instance.func_exc).bind(7, a=3, b=5, c=2, d=5) + bound_args = inspect.signature(instance.func_exc).bind(a=3, b=5, c=2, d=5) args, kwargs = mock_span_drop.call_args assert args == () assert kwargs == { @@ -197,12 +256,12 @@ async def test_dispatcher_async_span_args(mock_uuid, mock_span_enter, mock_span_ mock_uuid.uuid4.return_value = "mock" # act - result = await async_func(1, 2, a=3, c=5) + result = await async_func(a=3, c=5) # assert # span_enter span_id = f"{async_func.__qualname__}-mock" - bound_args = inspect.signature(async_func).bind(1, 2, a=3, c=5) + bound_args = inspect.signature(async_func).bind(a=3, c=5) mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () @@ -231,12 +290,12 @@ async def test_dispatcher_async_span_args_with_instance( # act instance = _TestObject() - result = await instance.async_func(1, 2, a=3, c=5) + result = await instance.async_func(a=3, c=5) # assert # span_enter span_id = f"{instance.async_func.__qualname__}-mock" - bound_args = inspect.signature(instance.async_func).bind(1, 2, a=3, c=5) + bound_args = inspect.signature(instance.async_func).bind(a=3, c=5) mock_span_enter.assert_called_once() args, kwargs = mock_span_enter.call_args assert args == () @@ -269,7 +328,7 @@ async def test_dispatcher_async_span_drop_args( with pytest.raises(CancelledError): # act - _ = await async_func_exc(7, a=3, b=5, c=2, d=5) + _ = await async_func_exc(a=3, b=5, c=2, d=5) # assert # span_enter @@ -278,7 +337,7 @@ async def test_dispatcher_async_span_drop_args( # span_drop mock_span_drop.assert_called_once() span_id = f"{async_func_exc.__qualname__}-mock" - bound_args = inspect.signature(async_func_exc).bind(7, a=3, b=5, c=2, d=5) + bound_args = inspect.signature(async_func_exc).bind(a=3, b=5, c=2, d=5) args, kwargs = mock_span_drop.call_args assert args == () assert kwargs == { @@ -309,7 +368,7 @@ async def test_dispatcher_async_span_drop_args_with_instance( with pytest.raises(CancelledError): # act instance = _TestObject() - _ = await instance.async_func_exc(7, a=3, b=5, c=2, d=5) + _ = await instance.async_func_exc(a=3, b=5, c=2, d=5) # assert # span_enter @@ -318,7 +377,7 @@ async def test_dispatcher_async_span_drop_args_with_instance( # span_drop mock_span_drop.assert_called_once() span_id = f"{instance.async_func_exc.__qualname__}-mock" - bound_args = inspect.signature(instance.async_func_exc).bind(7, a=3, b=5, c=2, d=5) + bound_args = inspect.signature(instance.async_func_exc).bind(a=3, b=5, c=2, d=5) args, kwargs = mock_span_drop.call_args assert args == () assert kwargs == { @@ -330,3 +389,138 @@ async def test_dispatcher_async_span_drop_args_with_instance( # span_exit mock_span_exit.assert_not_called() + + +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_drop") 
+@patch.object(Dispatcher, "span_enter") +@patch("llama_index.core.instrumentation.dispatcher.uuid") +def test_dispatcher_fire_event( + mock_uuid: MagicMock, + mock_span_enter: MagicMock, + mock_span_drop: MagicMock, + mock_span_exit: MagicMock, +): + # arrange + mock_uuid.uuid4.return_value = "mock" + event_handler = _TestEventHandler() + dispatcher.add_event_handler(event_handler) + + # act + _ = func_with_event(3, c=5) + + # assert + span_id = f"{func_with_event.__qualname__}-mock" + assert all(e.span_id == span_id for e in event_handler.events) + + # span_enter + mock_span_enter.assert_called_once() + + # span + mock_span_drop.assert_not_called() + + # span_exit + mock_span_exit.assert_called_once() + + +@pytest.mark.asyncio() +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_drop") +@patch.object(Dispatcher, "span_enter") +async def test_dispatcher_async_fire_event( + mock_span_enter: MagicMock, + mock_span_drop: MagicMock, + mock_span_exit: MagicMock, +): + # arrange + event_handler = _TestEventHandler() + dispatcher.add_event_handler(event_handler) + + # act + tasks = [ + async_func_with_event(a=3, c=5), + async_func_with_event(5), + async_func_with_event(4), + ] + _ = await asyncio.gather(*tasks) + + # assert + span_ids = [e.span_id for e in event_handler.events] + id_counts = Counter(span_ids) + assert set(id_counts.values()) == {2} + + # span_enter + mock_span_enter.call_count == 3 + + # span + mock_span_drop.assert_not_called() + + # span_exit + mock_span_exit.call_count == 3 + + +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_drop") +@patch.object(Dispatcher, "span_enter") +@patch("llama_index.core.instrumentation.dispatcher.uuid") +def test_dispatcher_fire_event_with_instance( + mock_uuid, mock_span_enter, mock_span_drop, mock_span_exit +): + # arrange + mock_uuid.uuid4.return_value = "mock" + event_handler = _TestEventHandler() + dispatcher.add_event_handler(event_handler) + + # act + instance = _TestObject() + _ = instance.func_with_event(a=3, c=5) + + # assert + span_id = f"{instance.func_with_event.__qualname__}-mock" + assert all(e.span_id == span_id for e in event_handler.events) + + # span_enter + mock_span_enter.assert_called_once() + + # span + mock_span_drop.assert_not_called() + + # span_exit + mock_span_exit.assert_called_once() + + +@pytest.mark.asyncio() +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_drop") +@patch.object(Dispatcher, "span_enter") +async def test_dispatcher_async_fire_event_with_instance( + mock_span_enter: MagicMock, + mock_span_drop: MagicMock, + mock_span_exit: MagicMock, +): + # arrange + # mock_uuid.return_value = "mock" + event_handler = _TestEventHandler() + dispatcher.add_event_handler(event_handler) + + # act + instance = _TestObject() + tasks = [ + instance.async_func_with_event(a=3, c=5), + instance.async_func_with_event(5), + ] + _ = await asyncio.gather(*tasks) + + # assert + span_ids = [e.span_id for e in event_handler.events] + id_counts = Counter(span_ids) + assert set(id_counts.values()) == {2} + + # span_enter + mock_span_enter.call_count == 2 + + # span + mock_span_drop.assert_not_called() + + # span_exit + mock_span_exit.call_count == 2 diff --git a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD b/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD deleted file mode 100644 index dabf212d7e..0000000000 --- a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD +++ /dev/null @@ -1 +0,0 @@ 
-python_tests() diff --git a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/__init__.py b/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py b/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py deleted file mode 100644 index c1d54abd7c..0000000000 --- a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py +++ /dev/null @@ -1,7 +0,0 @@ -from llama_index.callbacks.aim.base import AimCallback -from llama_index.core.callbacks.base_handler import BaseCallbackHandler - - -def test_class(): - names_of_base_classes = [b.__name__ for b in AimCallback.__mro__] - assert BaseCallbackHandler.__name__ in names_of_base_classes -- GitLab
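Usage sketch (not part of the patch above): the pattern this change introduces is to decorate a function with @dispatcher.span and then fire events through the partial returned by get_dispatch_event(), so that every event reaching registered handlers carries the span_id of its enclosing span. The dispatcher APIs below (get_dispatcher, span, get_dispatch_event, add_event_handler, BaseEvent, BaseEventHandler) are the ones added or exercised in this diff; MyStartEvent, PrintingHandler, and do_work are hypothetical names used only for illustration.

import llama_index.core.instrumentation as instrument
from llama_index.core.instrumentation.events.base import BaseEvent
from llama_index.core.instrumentation.event_handlers import BaseEventHandler

dispatcher = instrument.get_dispatcher(__name__)


class MyStartEvent(BaseEvent):  # hypothetical event type
    @classmethod
    def class_name(cls):
        return "MyStartEvent"


class PrintingHandler(BaseEventHandler):  # hypothetical handler
    @classmethod
    def class_name(cls):
        return "PrintingHandler"

    def handle(self, event: BaseEvent) -> None:
        # span_id ties each event back to the span that emitted it.
        print(event.span_id, event.class_name())


dispatcher.add_event_handler(PrintingHandler())


@dispatcher.span
def do_work(a, b=3):
    # Bind the enclosing span's id once; events fired through the returned
    # partial keep that span_id even if nested calls open new spans.
    dispatch_event = dispatcher.get_dispatch_event()
    dispatch_event(MyStartEvent())
    return a + b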