From 95be930f5f715ba2268483b4c31b55c4c78e3e6d Mon Sep 17 00:00:00 2001
From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com>
Date: Tue, 2 Apr 2024 00:20:48 -0400
Subject: [PATCH] Add `span_id` attribute to Events (instrumentation) (#12417)

* add span_id to Event

* remove raise err in NullSpanHandler

* wip

* modify root dispatcher event enclosing span

* remove *args as we have bound_args now

* add LLMChatInProgressEvent

* add LLMStructuredPredict events

* store span_id before await executions

* add SpanDropEvent with err_str payload

* add event to _achat; flush current_span_id when open_spans is empty

* llm callbacks use root span_id

* add unit tests

* remove print statements

* provide context manager returning a dispatch event partial with correct span id

* move to context manager usage

* fix invocation of cm

* define and use get_dispatch_event method (see the usage sketch below)

* remove aim tests
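
For reference, this is the usage pattern the patch applies throughout the
codebase, as a minimal sketch (MyStartEvent and do_work are hypothetical
stand-ins for illustration, not names from this diff):

    import llama_index.core.instrumentation as instrument
    from llama_index.core.instrumentation.events.base import BaseEvent

    dispatcher = instrument.get_dispatcher(__name__)

    class MyStartEvent(BaseEvent):
        @classmethod
        def class_name(cls):
            return "MyStartEvent"

    @dispatcher.span
    def do_work(a: int, b: int = 3) -> int:
        # Bind the current span_id up front, so events fired later
        # (e.g. after awaits in async code) still carry the right span_id.
        dispatch_event = dispatcher.get_dispatch_event()
        dispatch_event(MyStartEvent())
        return a + b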
---
 .../llama_index/core/agent/runner/base.py     |  23 +-
 .../core/base/base_query_engine.py            |  12 +-
 .../llama_index/core/base/base_retriever.py   |  30 ++-
 .../llama_index/core/base/embeddings/base.py  |  90 +++++--
 .../llama_index/core/chat_engine/types.py     |  26 +-
 .../core/instrumentation/__init__.py          |   5 +-
 .../core/instrumentation/dispatcher.py        |  92 ++++++-
 .../core/instrumentation/events/base.py       |   1 +
 .../core/instrumentation/events/llm.py        |  24 ++
 .../core/instrumentation/events/span.py       |  10 +
 .../core/instrumentation/span/simple.py       |   3 +-
 .../instrumentation/span_handlers/base.py     |  16 +-
 .../instrumentation/span_handlers/null.py     |   8 +-
 .../instrumentation/span_handlers/simple.py   |  18 +-
 .../llama_index/core/llms/callbacks.py        |  37 ++-
 llama-index-core/llama_index/core/llms/llm.py |  29 ++-
 .../core/response_synthesizers/base.py        |  64 ++++-
 .../core/response_synthesizers/refine.py      |  11 +-
 .../tests/instrumentation/test_dispatcher.py  | 242 ++++++++++++++++--
 .../llama-index-callbacks-aim/tests/BUILD     |   1 -
 .../tests/__init__.py                         |   0
 .../tests/test_aim_callback.py                |   7 -
 22 files changed, 617 insertions(+), 132 deletions(-)
 create mode 100644 llama-index-core/llama_index/core/instrumentation/events/span.py
 delete mode 100644 llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD
 delete mode 100644 llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/__init__.py
 delete mode 100644 llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py

diff --git a/llama-index-core/llama_index/core/agent/runner/base.py b/llama-index-core/llama_index/core/agent/runner/base.py
index c5759ffed4..2f17a796d6 100644
--- a/llama-index-core/llama_index/core/agent/runner/base.py
+++ b/llama-index-core/llama_index/core/agent/runner/base.py
@@ -365,7 +365,9 @@ class AgentRunner(BaseAgentRunner):
         **kwargs: Any,
     ) -> TaskStepOutput:
         """Execute step."""
-        dispatcher.event(AgentRunStepStartEvent())
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(AgentRunStepStartEvent())
         task = self.state.get_task(task_id)
         step_queue = self.state.get_step_queue(task_id)
         step = step or step_queue.popleft()
@@ -392,7 +394,7 @@ class AgentRunner(BaseAgentRunner):
         completed_steps = self.state.get_completed_steps(task_id)
         completed_steps.append(cur_step_output)
 
-        dispatcher.event(AgentRunStepEndEvent())
+        dispatch_event(AgentRunStepEndEvent())
         return cur_step_output
 
     @dispatcher.span
@@ -405,6 +407,9 @@ class AgentRunner(BaseAgentRunner):
         **kwargs: Any,
     ) -> TaskStepOutput:
         """Execute step."""
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(AgentRunStepStartEvent())
         task = self.state.get_task(task_id)
         step_queue = self.state.get_step_queue(task_id)
         step = step or step_queue.popleft()
@@ -430,6 +435,7 @@ class AgentRunner(BaseAgentRunner):
         completed_steps = self.state.get_completed_steps(task_id)
         completed_steps.append(cur_step_output)
 
+        dispatch_event(AgentRunStepEndEvent())
         return cur_step_output
 
     @dispatcher.span
@@ -528,12 +534,14 @@ class AgentRunner(BaseAgentRunner):
         mode: ChatResponseMode = ChatResponseMode.WAIT,
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Chat with step executor."""
+        dispatch_event = dispatcher.get_dispatch_event()
+
         if chat_history is not None:
             self.memory.set(chat_history)
         task = self.create_task(message)
 
         result_output = None
-        dispatcher.event(AgentChatWithStepStartEvent())
+        dispatch_event(AgentChatWithStepStartEvent())
         while True:
             # pass step queue in as argument, assume step executor is stateless
             cur_step_output = self._run_step(
@@ -551,7 +559,7 @@ class AgentRunner(BaseAgentRunner):
             task.task_id,
             result_output,
         )
-        dispatcher.event(AgentChatWithStepEndEvent())
+        dispatch_event(AgentChatWithStepEndEvent())
         return result
 
     @dispatcher.span
@@ -563,11 +571,14 @@ class AgentRunner(BaseAgentRunner):
         mode: ChatResponseMode = ChatResponseMode.WAIT,
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Chat with step executor."""
+        dispatch_event = dispatcher.get_dispatch_event()
+
         if chat_history is not None:
             self.memory.set(chat_history)
         task = self.create_task(message)
 
         result_output = None
+        dispatch_event(AgentChatWithStepStartEvent())
         while True:
             # pass step queue in as argument, assume step executor is stateless
             cur_step_output = await self._arun_step(
@@ -581,10 +592,12 @@ class AgentRunner(BaseAgentRunner):
             # ensure tool_choice does not cause endless loops
             tool_choice = "auto"
 
-        return self.finalize_response(
+        result = self.finalize_response(
             task.task_id,
             result_output,
         )
+        dispatch_event(AgentChatWithStepEndEvent())
+        return result
 
     @dispatcher.span
     @trace_method("chat")
diff --git a/llama-index-core/llama_index/core/base/base_query_engine.py b/llama-index-core/llama_index/core/base/base_query_engine.py
index b81fd17396..42f6b33e8c 100644
--- a/llama-index-core/llama_index/core/base/base_query_engine.py
+++ b/llama-index-core/llama_index/core/base/base_query_engine.py
@@ -44,22 +44,26 @@ class BaseQueryEngine(ChainableMixin, PromptMixin):
 
     @dispatcher.span
     def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
-        dispatcher.event(QueryStartEvent())
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(QueryStartEvent())
         with self.callback_manager.as_trace("query"):
             if isinstance(str_or_query_bundle, str):
                 str_or_query_bundle = QueryBundle(str_or_query_bundle)
             query_result = self._query(str_or_query_bundle)
-        dispatcher.event(QueryEndEvent())
+        dispatch_event(QueryEndEvent())
         return query_result
 
     @dispatcher.span
     async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
-        dispatcher.event(QueryStartEvent())
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(QueryStartEvent())
         with self.callback_manager.as_trace("query"):
             if isinstance(str_or_query_bundle, str):
                 str_or_query_bundle = QueryBundle(str_or_query_bundle)
             query_result = await self._aquery(str_or_query_bundle)
-        dispatcher.event(QueryEndEvent())
+        dispatch_event(QueryEndEvent())
         return query_result
 
     def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
diff --git a/llama-index-core/llama_index/core/base/base_retriever.py b/llama-index-core/llama_index/core/base/base_retriever.py
index 6dc7649de6..f87872b0ff 100644
--- a/llama-index-core/llama_index/core/base/base_retriever.py
+++ b/llama-index-core/llama_index/core/base/base_retriever.py
@@ -224,8 +224,14 @@ class BaseRetriever(ChainableMixin, PromptMixin):
                 a QueryBundle object.
 
         """
+        dispatch_event = dispatcher.get_dispatch_event()
+
         self._check_callback_manager()
-        dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle))
+        dispatch_event(
+            RetrievalStartEvent(
+                str_or_query_bundle=str_or_query_bundle,
+            )
+        )
         if isinstance(str_or_query_bundle, str):
             query_bundle = QueryBundle(str_or_query_bundle)
         else:
@@ -240,15 +246,24 @@ class BaseRetriever(ChainableMixin, PromptMixin):
                 retrieve_event.on_end(
                     payload={EventPayload.NODES: nodes},
                 )
-        dispatcher.event(
-            RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes)
+        dispatch_event(
+            RetrievalEndEvent(
+                str_or_query_bundle=str_or_query_bundle,
+                nodes=nodes,
+            )
         )
         return nodes
 
     @dispatcher.span
     async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
         self._check_callback_manager()
-        dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            RetrievalStartEvent(
+                str_or_query_bundle=str_or_query_bundle,
+            )
+        )
         if isinstance(str_or_query_bundle, str):
             query_bundle = QueryBundle(str_or_query_bundle)
         else:
@@ -265,8 +280,11 @@ class BaseRetriever(ChainableMixin, PromptMixin):
                 retrieve_event.on_end(
                     payload={EventPayload.NODES: nodes},
                 )
-        dispatcher.event(
-            RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes)
+        dispatch_event(
+            RetrievalEndEvent(
+                str_or_query_bundle=str_or_query_bundle,
+                nodes=nodes,
+            )
         )
         return nodes
 
diff --git a/llama-index-core/llama_index/core/base/embeddings/base.py b/llama-index-core/llama_index/core/base/embeddings/base.py
index 5924aca1a3..3f5c0cf849 100644
--- a/llama-index-core/llama_index/core/base/embeddings/base.py
+++ b/llama-index-core/llama_index/core/base/embeddings/base.py
@@ -114,7 +114,13 @@ class BaseEmbedding(TransformComponent):
         other examples of predefined instructions can be found in
         embeddings/huggingface_utils.py.
         """
-        dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            EmbeddingStartEvent(
+                model_dict=self.to_dict(),
+            )
+        )
         with self.callback_manager.event(
             CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
         ) as event:
@@ -126,15 +132,24 @@ class BaseEmbedding(TransformComponent):
                     EventPayload.EMBEDDINGS: [query_embedding],
                 },
             )
-        dispatcher.event(
-            EmbeddingEndEvent(chunks=[query], embeddings=[query_embedding])
+        dispatch_event(
+            EmbeddingEndEvent(
+                chunks=[query],
+                embeddings=[query_embedding],
+            )
         )
         return query_embedding
 
     @dispatcher.span
     async def aget_query_embedding(self, query: str) -> Embedding:
         """Get query embedding."""
-        dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            EmbeddingStartEvent(
+                model_dict=self.to_dict(),
+            )
+        )
         with self.callback_manager.event(
             CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
         ) as event:
@@ -146,8 +161,11 @@ class BaseEmbedding(TransformComponent):
                     EventPayload.EMBEDDINGS: [query_embedding],
                 },
             )
-        dispatcher.event(
-            EmbeddingEndEvent(chunks=[query], embeddings=[query_embedding])
+        dispatch_event(
+            EmbeddingEndEvent(
+                chunks=[query],
+                embeddings=[query_embedding],
+            )
         )
         return query_embedding
 
@@ -220,7 +238,13 @@ class BaseEmbedding(TransformComponent):
         document for retrieval: ". If you're curious, other examples of
         predefined instructions can be found in embeddings/huggingface_utils.py.
         """
-        dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            EmbeddingStartEvent(
+                model_dict=self.to_dict(),
+            )
+        )
         with self.callback_manager.event(
             CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
         ) as event:
@@ -232,13 +256,24 @@ class BaseEmbedding(TransformComponent):
                     EventPayload.EMBEDDINGS: [text_embedding],
                 }
             )
-        dispatcher.event(EmbeddingEndEvent(chunks=[text], embeddings=[text_embedding]))
+        dispatch_event(
+            EmbeddingEndEvent(
+                chunks=[text],
+                embeddings=[text_embedding],
+            )
+        )
         return text_embedding
 
     @dispatcher.span
     async def aget_text_embedding(self, text: str) -> Embedding:
         """Async get text embedding."""
-        dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            EmbeddingStartEvent(
+                model_dict=self.to_dict(),
+            )
+        )
         with self.callback_manager.event(
             CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
         ) as event:
@@ -250,7 +285,12 @@ class BaseEmbedding(TransformComponent):
                     EventPayload.EMBEDDINGS: [text_embedding],
                 }
             )
-        dispatcher.event(EmbeddingEndEvent(chunks=[text], embeddings=[text_embedding]))
+        dispatch_event(
+            EmbeddingEndEvent(
+                chunks=[text],
+                embeddings=[text_embedding],
+            )
+        )
         return text_embedding
 
     @dispatcher.span
@@ -261,6 +301,8 @@ class BaseEmbedding(TransformComponent):
         **kwargs: Any,
     ) -> List[Embedding]:
         """Get a list of text embeddings, with batching."""
+        dispatch_event = dispatcher.get_dispatch_event()
+
         cur_batch: List[str] = []
         result_embeddings: List[Embedding] = []
 
@@ -272,7 +314,11 @@ class BaseEmbedding(TransformComponent):
             cur_batch.append(text)
             if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
                 # flush
-                dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
+                dispatch_event(
+                    EmbeddingStartEvent(
+                        model_dict=self.to_dict(),
+                    )
+                )
                 with self.callback_manager.event(
                     CBEventType.EMBEDDING,
                     payload={EventPayload.SERIALIZED: self.to_dict()},
@@ -285,8 +331,11 @@ class BaseEmbedding(TransformComponent):
                             EventPayload.EMBEDDINGS: embeddings,
                         },
                     )
-                dispatcher.event(
-                    EmbeddingEndEvent(chunks=cur_batch, embeddings=embeddings)
+                dispatch_event(
+                    EmbeddingEndEvent(
+                        chunks=cur_batch,
+                        embeddings=embeddings,
+                    )
                 )
                 cur_batch = []
 
@@ -297,6 +346,8 @@ class BaseEmbedding(TransformComponent):
         self, texts: List[str], show_progress: bool = False
     ) -> List[Embedding]:
         """Asynchronously get a list of text embeddings, with batching."""
+        dispatch_event = dispatcher.get_dispatch_event()
+
         cur_batch: List[str] = []
         callback_payloads: List[Tuple[str, List[str]]] = []
         result_embeddings: List[Embedding] = []
@@ -305,7 +356,11 @@ class BaseEmbedding(TransformComponent):
             cur_batch.append(text)
             if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
                 # flush
-                dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
+                dispatch_event(
+                    EmbeddingStartEvent(
+                        model_dict=self.to_dict(),
+                    )
+                )
                 event_id = self.callback_manager.on_event_start(
                     CBEventType.EMBEDDING,
                     payload={EventPayload.SERIALIZED: self.to_dict()},
@@ -337,8 +392,11 @@ class BaseEmbedding(TransformComponent):
         for (event_id, text_batch), embeddings in zip(
             callback_payloads, nested_embeddings
         ):
-            dispatcher.event(
-                EmbeddingEndEvent(chunks=text_batch, embeddings=embeddings)
+            dispatch_event(
+                EmbeddingEndEvent(
+                    chunks=text_batch,
+                    embeddings=embeddings,
+                )
             )
             self.callback_manager.on_event_end(
                 CBEventType.EMBEDDING,
diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py
index 9291565269..72753cb92e 100644
--- a/llama-index-core/llama_index/core/chat_engine/types.py
+++ b/llama-index-core/llama_index/core/chat_engine/types.py
@@ -127,15 +127,20 @@ class StreamingAgentChatResponse:
             raise ValueError(
                 "chat_stream is None. Cannot write to history without chat_stream."
             )
+        dispatch_event = dispatcher.get_dispatch_event()
 
         # try/except to prevent hanging on error
-        dispatcher.event(StreamChatStartEvent())
+        dispatch_event(StreamChatStartEvent())
         try:
             final_text = ""
             for chat in self.chat_stream:
                 self._is_function = is_function(chat.message)
                 if chat.delta:
-                    dispatcher.event(StreamChatDeltaReceivedEvent(delta=chat.delta))
+                    dispatch_event(
+                        StreamChatDeltaReceivedEvent(
+                            delta=chat.delta,
+                        )
+                    )
                     self.put_in_queue(chat.delta)
                 final_text += chat.delta or ""
             if self._is_function is not None:  # if loop has gone through iteration
@@ -144,14 +149,14 @@ class StreamingAgentChatResponse:
                 chat.message.content = final_text.strip()  # final message
                 memory.put(chat.message)
         except Exception as e:
-            dispatcher.event(StreamChatErrorEvent())
+            dispatch_event(StreamChatErrorEvent())
             if not raise_error:
                 logger.warning(
                     f"Encountered exception writing response to history: {e}"
                 )
             else:
                 raise
-        dispatcher.event(StreamChatEndEvent())
+        dispatch_event(StreamChatEndEvent())
 
         self._is_done = True
 
@@ -167,6 +172,7 @@ class StreamingAgentChatResponse:
         on_stream_end_fn: Optional[callable] = None,
     ) -> None:
         self._ensure_async_setup()
+        dispatch_event = dispatcher.get_dispatch_event()
 
         if self.achat_stream is None:
             raise ValueError(
@@ -175,13 +181,17 @@ class StreamingAgentChatResponse:
             )
 
         # try/except to prevent hanging on error
-        dispatcher.event(StreamChatStartEvent())
+        dispatch_event(StreamChatStartEvent())
         try:
             final_text = ""
             async for chat in self.achat_stream:
                 self._is_function = is_function(chat.message)
                 if chat.delta:
-                    dispatcher.event(StreamChatDeltaReceivedEvent(delta=chat.delta))
+                    dispatch_event(
+                        StreamChatDeltaReceivedEvent(
+                            delta=chat.delta,
+                        )
+                    )
                     self.aput_in_queue(chat.delta)
                 final_text += chat.delta or ""
                 self._new_item_event.set()
@@ -193,9 +203,9 @@ class StreamingAgentChatResponse:
                 chat.message.content = final_text.strip()  # final message
                 memory.put(chat.message)
         except Exception as e:
-            dispatcher.event(StreamChatErrorEvent())
+            dispatch_event(StreamChatErrorEvent())
             logger.warning(f"Encountered exception writing response to history: {e}")
-        dispatcher.event(StreamChatEndEvent())
+        dispatch_event(StreamChatEndEvent())
         self._is_done = True
 
         # These act as is_done events for any consumers waiting
diff --git a/llama-index-core/llama_index/core/instrumentation/__init__.py b/llama-index-core/llama_index/core/instrumentation/__init__.py
index b7564a02f7..fd41c2ec17 100644
--- a/llama-index-core/llama_index/core/instrumentation/__init__.py
+++ b/llama-index-core/llama_index/core/instrumentation/__init__.py
@@ -24,7 +24,10 @@ def get_dispatcher(name: str = "root") -> Dispatcher:
         parent_name = "root"
 
     new_dispatcher = Dispatcher(
-        name=name, root=root_dispatcher, parent_name=parent_name, manager=root_manager
+        name=name,
+        root_name=root_dispatcher.name,
+        parent_name=parent_name,
+        manager=root_manager,
     )
     root_manager.add_dispatcher(new_dispatcher)
     return new_dispatcher
diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
index 727d7cba82..c4b659aa06 100644
--- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py
+++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
@@ -1,16 +1,33 @@
-from typing import Any, List, Optional, Dict
+from typing import Any, List, Optional, Dict, Protocol
+from functools import partial
+from contextlib import contextmanager
+import asyncio
 import inspect
 import uuid
-from llama_index.core.bridge.pydantic import BaseModel, Field
+from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
+from llama_index.core.instrumentation.events import BaseEvent
 from llama_index.core.instrumentation.event_handlers import BaseEventHandler
 from llama_index.core.instrumentation.span_handlers import (
     BaseSpanHandler,
     NullSpanHandler,
 )
 from llama_index.core.instrumentation.events.base import BaseEvent
+from llama_index.core.instrumentation.events.span import SpanDropEvent
 import wrapt
 
 
+class EventDispatcher(Protocol):
+    def __call__(self, event: BaseEvent) -> None:
+        ...
+
+
+class EventContext(BaseModel):
+    span_id: str = Field(default="")
+
+
+event_context = EventContext()
+
+
 class Dispatcher(BaseModel):
     name: str = Field(default_factory=str, description="Name of dispatcher")
     event_handlers: List[BaseEventHandler] = Field(
@@ -30,6 +47,31 @@ class Dispatcher(BaseModel):
         default=True,
         description="Whether to propagate the event to parent dispatchers and their handlers",
     )
+    current_span_id: Optional[str] = Field(
+        default=None, description="Id of current span."
+    )
+    _asyncio_lock: asyncio.Lock = PrivateAttr()
+
+    def __init__(
+        self,
+        name: str = "",
+        event_handlers: List[BaseEventHandler] = [],
+        span_handlers: List[BaseSpanHandler] = [],
+        parent_name: str = "",
+        manager: Optional["Manager"] = None,
+        root_name: str = "root",
+        propagate: bool = True,
+    ):
+        self._asyncio_lock = asyncio.Lock()
+        super().__init__(
+            name=name,
+            event_handlers=event_handlers,
+            span_handlers=span_handlers,
+            parent_name=parent_name,
+            manager=manager,
+            root_name=root_name,
+            propagate=propagate,
+        )
 
     @property
     def parent(self) -> "Dispatcher":
@@ -47,9 +89,11 @@ class Dispatcher(BaseModel):
         """Add handler to set of handlers."""
         self.span_handlers += [handler]
 
-    def event(self, event: BaseEvent, **kwargs) -> None:
+    def event(self, event: BaseEvent, span_id: Optional[str] = None, **kwargs) -> None:
         """Dispatch event to all registered handlers."""
         c = self
+        if span_id:
+            event.span_id = span_id
         while c:
             for h in c.event_handlers:
                 h.handle(event, **kwargs)
@@ -60,7 +104,6 @@ class Dispatcher(BaseModel):
 
     def span_enter(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -71,7 +114,6 @@ class Dispatcher(BaseModel):
         while c:
             for h in c.span_handlers:
                 h.span_enter(
-                    *args,
                     id_=id_,
                     bound_args=bound_args,
                     instance=instance,
@@ -84,7 +126,6 @@ class Dispatcher(BaseModel):
 
     def span_drop(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -96,7 +137,6 @@ class Dispatcher(BaseModel):
         while c:
             for h in c.span_handlers:
                 h.span_drop(
-                    *args,
                     id_=id_,
                     bound_args=bound_args,
                     instance=instance,
@@ -110,7 +150,6 @@ class Dispatcher(BaseModel):
 
     def span_exit(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -122,7 +161,6 @@ class Dispatcher(BaseModel):
         while c:
             for h in c.span_handlers:
                 h.span_exit(
-                    *args,
                     id_=id_,
                     bound_args=bound_args,
                     instance=instance,
@@ -134,15 +172,45 @@ class Dispatcher(BaseModel):
             else:
                 c = c.parent
 
+    def get_dispatch_event(self) -> EventDispatcher:
+        """Get dispatch_event for firing events within the context of a span.
+
+        This method should be used with @dispatcher.span decorated
+        functions only. Otherwise, the span_id should not be trusted, as the
+        span decorator sets the span_id.
+        """
+        span_id = self.current_span_id
+        dispatch_event: EventDispatcher = partial(self.event, span_id=span_id)
+        return dispatch_event
+
+    @contextmanager
+    def dispatch_event(self):
+        """Context manager for firing events within a span session.
+
+        This context manager should be used with @dispatcher.span decorated
+        functions only. Otherwise, the span_id should not be trusted, as the
+        span decorator sets the span_id.
+        """
+        span_id = self.current_span_id
+        dispatch_event: EventDispatcher = partial(self.event, span_id=span_id)
+
+        try:
+            yield dispatch_event
+        finally:
+            del dispatch_event
+
     def span(self, func):
         @wrapt.decorator
         def wrapper(func, instance, args, kwargs):
             bound_args = inspect.signature(func).bind(*args, **kwargs)
             id_ = f"{func.__qualname__}-{uuid.uuid4()}"
+            self.current_span_id = id_
+            self.root.current_span_id = id_
             self.span_enter(id_=id_, bound_args=bound_args, instance=instance)
             try:
                 result = func(*args, **kwargs)
             except BaseException as e:
+                self.event(SpanDropEvent(span_id=id_, err_str=str(e)))
                 self.span_drop(id_=id_, bound_args=bound_args, instance=instance, err=e)
                 raise
             else:
@@ -155,10 +223,16 @@ class Dispatcher(BaseModel):
         async def async_wrapper(func, instance, args, kwargs):
             bound_args = inspect.signature(func).bind(*args, **kwargs)
             id_ = f"{func.__qualname__}-{uuid.uuid4()}"
+            async with self._asyncio_lock:
+                self.current_span_id = id_
+            async with self.root._asyncio_lock:
+                self.root.current_span_id = id_
+
             self.span_enter(id_=id_, bound_args=bound_args, instance=instance)
             try:
                 result = await func(*args, **kwargs)
             except BaseException as e:
+                self.event(SpanDropEvent(span_id=id_, err_str=str(e)))
                 self.span_drop(id_=id_, bound_args=bound_args, instance=instance, err=e)
                 raise
             else:
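
Note: alongside get_dispatch_event, the patch also adds the dispatch_event
context manager above, which yields the same span-bound partial. A sketch of
that form, reusing the hypothetical MyStartEvent and dispatcher from the
sketch in the commit message:

    @dispatcher.span
    def do_work_cm(a: int, b: int = 3) -> int:
        # Yields an EventDispatcher partial bound to the span_id that is
        # current at entry to the context manager.
        with dispatcher.dispatch_event() as dispatch_event:
            dispatch_event(MyStartEvent())
            return a + b
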
diff --git a/llama-index-core/llama_index/core/instrumentation/events/base.py b/llama-index-core/llama_index/core/instrumentation/events/base.py
index 061aea2605..28f3d1dba7 100644
--- a/llama-index-core/llama_index/core/instrumentation/events/base.py
+++ b/llama-index-core/llama_index/core/instrumentation/events/base.py
@@ -7,6 +7,7 @@ from datetime import datetime
 class BaseEvent(BaseModel):
     timestamp: datetime = Field(default_factory=lambda: datetime.now())
     id_: str = Field(default_factory=lambda: uuid4())
+    span_id: str = Field(default_factory=str)
 
     @classmethod
     def class_name(cls):
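
Note: since every BaseEvent now carries a span_id, event handlers can
correlate events with the span that fired them. A minimal illustrative
handler (the bucketing logic is an example, not part of this patch; the bare
class attribute mirrors the style of the test handler later in this diff):

    from collections import defaultdict

    from llama_index.core.instrumentation.event_handlers import BaseEventHandler
    from llama_index.core.instrumentation.events.base import BaseEvent

    class SpanGroupingHandler(BaseEventHandler):
        # plain class attribute, not a pydantic field
        events_by_span = defaultdict(list)

        @classmethod
        def class_name(cls):
            return "SpanGroupingHandler"

        def handle(self, event: BaseEvent, **kwargs) -> None:
            # span_id ties the event back to the @dispatcher.span that fired it
            self.events_by_span[event.span_id].append(event)
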
diff --git a/llama-index-core/llama_index/core/instrumentation/events/llm.py b/llama-index-core/llama_index/core/instrumentation/events/llm.py
index b61dde1b84..28a81204a5 100644
--- a/llama-index-core/llama_index/core/instrumentation/events/llm.py
+++ b/llama-index-core/llama_index/core/instrumentation/events/llm.py
@@ -22,6 +22,20 @@ class LLMPredictEndEvent(BaseEvent):
         return "LLMPredictEndEvent"
 
 
+class LLMStructuredPredictStartEvent(BaseEvent):
+    @classmethod
+    def class_name(cls):
+        """Class name."""
+        return "LLMStructuredPredictStartEvent"
+
+
+class LLMStructuredPredictEndEvent(BaseEvent):
+    @classmethod
+    def class_name(cls):
+        """Class name."""
+        return "LLMStructuredPredictEndEvent"
+
+
 class LLMCompletionStartEvent(BaseEvent):
     prompt: str
     additional_kwargs: dict
@@ -54,6 +68,16 @@ class LLMChatStartEvent(BaseEvent):
         return "LLMChatStartEvent"
 
 
+class LLMChatInProgressEvent(BaseEvent):
+    messages: List[ChatMessage]
+    response: ChatResponse
+
+    @classmethod
+    def class_name(cls):
+        """Class name."""
+        return "LLMChatInProgressEvent"
+
+
 class LLMChatEndEvent(BaseEvent):
     messages: List[ChatMessage]
     response: ChatResponse
diff --git a/llama-index-core/llama_index/core/instrumentation/events/span.py b/llama-index-core/llama_index/core/instrumentation/events/span.py
new file mode 100644
index 0000000000..113932e5b0
--- /dev/null
+++ b/llama-index-core/llama_index/core/instrumentation/events/span.py
@@ -0,0 +1,10 @@
+from llama_index.core.instrumentation.events.base import BaseEvent
+
+
+class SpanDropEvent(BaseEvent):
+    err_str: str
+
+    @classmethod
+    def class_name(cls):
+        """Class name."""
+        return "SpanDropEvent"
diff --git a/llama-index-core/llama_index/core/instrumentation/span/simple.py b/llama-index-core/llama_index/core/instrumentation/span/simple.py
index ceef9338e1..7905b743f2 100644
--- a/llama-index-core/llama_index/core/instrumentation/span/simple.py
+++ b/llama-index-core/llama_index/core/instrumentation/span/simple.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Dict, Optional
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.instrumentation.span.base import BaseSpan
 from datetime import datetime
@@ -10,3 +10,4 @@ class SimpleSpan(BaseSpan):
     start_time: datetime = Field(default_factory=lambda: datetime.now())
     end_time: Optional[datetime] = Field(default=None)
     duration: float = Field(default=float, description="Duration of span in seconds.")
+    metadata: Optional[Dict] = Field(default=None)
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
index 78beb7c295..725cfaa41f 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
@@ -1,6 +1,6 @@
 import inspect
 from abc import abstractmethod
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Dict, List, Generic, Optional, TypeVar
 
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.instrumentation.span.base import BaseSpan
@@ -12,6 +12,12 @@ class BaseSpanHandler(BaseModel, Generic[T]):
     open_spans: Dict[str, T] = Field(
         default_factory=dict, description="Dictionary of open spans."
     )
+    completed_spans: List[T] = Field(
+        default_factory=list, description="List of completed spans."
+    )
+    dropped_spans: List[T] = Field(
+        default_factory=list, description="List of dropped spans."
+    )
     current_span_id: Optional[str] = Field(
         default=None, description="Id of current span."
     )
@@ -25,7 +31,6 @@ class BaseSpanHandler(BaseModel, Generic[T]):
 
     def span_enter(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -48,7 +53,6 @@ class BaseSpanHandler(BaseModel, Generic[T]):
 
     def span_exit(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -63,10 +67,11 @@ class BaseSpanHandler(BaseModel, Generic[T]):
             if self.current_span_id == id_:
                 self.current_span_id = self.open_spans[id_].parent_id
             del self.open_spans[id_]
+        if not self.open_spans:  # empty so flush
+            self.current_span_id = None
 
     def span_drop(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -85,7 +90,6 @@ class BaseSpanHandler(BaseModel, Generic[T]):
     @abstractmethod
     def new_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -98,7 +102,6 @@ class BaseSpanHandler(BaseModel, Generic[T]):
     @abstractmethod
     def prepare_to_exit_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -111,7 +114,6 @@ class BaseSpanHandler(BaseModel, Generic[T]):
     @abstractmethod
     def prepare_to_drop_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py
index a4c6000ebc..02788db0e0 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py
@@ -12,7 +12,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
 
     def span_enter(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -23,7 +22,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
 
     def span_exit(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -35,7 +33,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
 
     def new_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -47,7 +44,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
 
     def prepare_to_exit_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -59,7 +55,6 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
 
     def prepare_to_drop_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -67,5 +62,4 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
         **kwargs: Any
     ) -> None:
         """Logic for droppping a span."""
-        if err:
-            raise err
+        return
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
index 18f9bddc4f..f0ae9b0fab 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
@@ -1,6 +1,5 @@
 import inspect
 from typing import Any, cast, List, Optional, TYPE_CHECKING
-from llama_index.core.bridge.pydantic import Field
 from llama_index.core.instrumentation.span.simple import SimpleSpan
 from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler
 from datetime import datetime
@@ -13,17 +12,12 @@ if TYPE_CHECKING:
 class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
     """Span Handler that managest SimpleSpan's."""
 
-    completed_spans: List[SimpleSpan] = Field(
-        default_factory=list, description="List of completed spans."
-    )
-
     def class_name(cls) -> str:
         """Class name."""
         return "SimpleSpanHandler"
 
     def new_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -35,7 +29,6 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
 
     def prepare_to_exit_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -52,7 +45,6 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
 
     def prepare_to_drop_span(
         self,
-        *args: Any,
         id_: str,
         bound_args: inspect.BoundArguments,
         instance: Optional[Any] = None,
@@ -61,7 +53,11 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
     ) -> SimpleSpan:
         """Logic for droppping a span."""
         if id_ in self.open_spans:
-            return self.open_spans[id_]
+            span = self.open_spans[id_]
+            span.metadata = {"error": str(err)}
+            self.dropped_spans += [span]
+            return span
+
         return None
 
     def _get_trace_trees(self) -> List["Tree"]:
@@ -74,7 +70,9 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
                 "`treelib` package is missing. Please install it by using "
                 "`pip install treelib`."
             )
-        sorted_spans = sorted(self.completed_spans, key=lambda x: x.start_time)
+        sorted_spans = sorted(
+            self.completed_spans + self.dropped_spans, key=lambda x: x.start_time
+        )
 
         trees = []
         tree = Tree()
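
Note: dropped spans are now retained with their error recorded in metadata
and included in the trace trees. A sketch, assuming the dispatcher's
add_span_handler registration method and the package-level SimpleSpanHandler
export (failing_function is a hypothetical stand-in):

    from llama_index.core.instrumentation.span_handlers import SimpleSpanHandler

    handler = SimpleSpanHandler()
    dispatcher.add_span_handler(handler)

    @dispatcher.span
    def failing_function():
        raise ValueError("boom")

    try:
        failing_function()
    except ValueError:
        pass
    # The dropped span is kept and annotated with the error string.
    print(handler.dropped_spans[0].metadata)  # {'error': 'boom'}
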
diff --git a/llama-index-core/llama_index/core/llms/callbacks.py b/llama-index-core/llama_index/core/llms/callbacks.py
index e3602007d1..1b9075ee6e 100644
--- a/llama-index-core/llama_index/core/llms/callbacks.py
+++ b/llama-index-core/llama_index/core/llms/callbacks.py
@@ -27,6 +27,7 @@ from llama_index.core.instrumentation.events.llm import (
     LLMCompletionStartEvent,
     LLMChatEndEvent,
     LLMChatStartEvent,
+    LLMChatInProgressEvent,
 )
 
 dispatcher = get_dispatcher(__name__)
@@ -49,11 +50,13 @@ def llm_chat_callback() -> Callable:
             _self: Any, messages: Sequence[ChatMessage], **kwargs: Any
         ) -> Any:
             with wrapper_logic(_self) as callback_manager:
+                span_id = dispatcher.root.current_span_id or ""
                 dispatcher.event(
                     LLMChatStartEvent(
                         model_dict=_self.to_dict(),
                         messages=messages,
                         additional_kwargs=kwargs,
+                        span_id=span_id,
                     )
                 )
                 event_id = callback_manager.on_event_start(
@@ -72,9 +75,10 @@ def llm_chat_callback() -> Callable:
                         last_response = None
                         async for x in f_return_val:
                             dispatcher.event(
-                                LLMChatEndEvent(
+                                LLMChatInProgressEvent(
                                     messages=messages,
                                     response=x,
+                                    span_id=span_id,
                                 )
                             )
                             yield cast(ChatResponse, x)
@@ -88,6 +92,13 @@ def llm_chat_callback() -> Callable:
                             },
                             event_id=event_id,
                         )
+                        dispatcher.event(
+                            LLMChatEndEvent(
+                                messages=messages,
+                                response=x,
+                                span_id=span_id,
+                            )
+                        )
 
                     return wrapped_gen()
                 else:
@@ -103,6 +114,7 @@ def llm_chat_callback() -> Callable:
                         LLMChatEndEvent(
                             messages=messages,
                             response=f_return_val,
+                            span_id=span_id,
                         )
                     )
 
@@ -112,11 +124,13 @@ def llm_chat_callback() -> Callable:
             _self: Any, messages: Sequence[ChatMessage], **kwargs: Any
         ) -> Any:
             with wrapper_logic(_self) as callback_manager:
+                span_id = dispatcher.root.current_span_id or ""
                 dispatcher.event(
                     LLMChatStartEvent(
                         model_dict=_self.to_dict(),
                         messages=messages,
                         additional_kwargs=kwargs,
+                        span_id=span_id,
                     )
                 )
                 event_id = callback_manager.on_event_start(
@@ -135,9 +149,10 @@ def llm_chat_callback() -> Callable:
                         last_response = None
                         for x in f_return_val:
                             dispatcher.event(
-                                LLMChatEndEvent(
+                                LLMChatInProgressEvent(
                                     messages=messages,
                                     response=x,
+                                    span_id=span_id,
                                 )
                             )
                             yield cast(ChatResponse, x)
@@ -151,6 +166,13 @@ def llm_chat_callback() -> Callable:
                             },
                             event_id=event_id,
                         )
+                        dispatcher.event(
+                            LLMChatEndEvent(
+                                messages=messages,
+                                response=x,
+                                span_id=span_id,
+                            )
+                        )
 
                     return wrapped_gen()
                 else:
@@ -166,6 +188,7 @@ def llm_chat_callback() -> Callable:
                         LLMChatEndEvent(
                             messages=messages,
                             response=f_return_val,
+                            span_id=span_id,
                         )
                     )
 
@@ -213,11 +236,13 @@ def llm_completion_callback() -> Callable:
             _self: Any, *args: Any, **kwargs: Any
         ) -> Any:
             with wrapper_logic(_self) as callback_manager:
+                span_id = dispatcher.root.current_span_id or ""
                 dispatcher.event(
                     LLMCompletionStartEvent(
                         model_dict=_self.to_dict(),
                         prompt=str(args[0]),
                         additional_kwargs=kwargs,
+                        span_id=span_id,
                     )
                 )
                 event_id = callback_manager.on_event_start(
@@ -240,6 +265,7 @@ def llm_completion_callback() -> Callable:
                                 LLMCompletionEndEvent(
                                     prompt=str(args[0]),
                                     response=x,
+                                    span_id=span_id,
                                 )
                             )
                             yield cast(CompletionResponse, x)
@@ -268,6 +294,7 @@ def llm_completion_callback() -> Callable:
                         LLMCompletionEndEvent(
                             prompt=str(args[0]),
                             response=f_return_val,
+                            span_id=span_id,
                         )
                     )
 
@@ -275,11 +302,13 @@ def llm_completion_callback() -> Callable:
 
         def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
             with wrapper_logic(_self) as callback_manager:
+                span_id = dispatcher.root.current_span_id or ""
                 dispatcher.event(
                     LLMCompletionStartEvent(
                         model_dict=_self.to_dict(),
                         prompt=str(args[0]),
                         additional_kwargs=kwargs,
+                        span_id=span_id,
                     )
                 )
                 event_id = callback_manager.on_event_start(
@@ -299,8 +328,7 @@ def llm_completion_callback() -> Callable:
                         for x in f_return_val:
                             dispatcher.event(
                                 LLMCompletionEndEvent(
-                                    prompt=str(args[0]),
-                                    response=x,
+                                    prompt=str(args[0]), response=x, span_id=span_id
                                 )
                             )
                             yield cast(CompletionResponse, x)
@@ -329,6 +357,7 @@ def llm_completion_callback() -> Callable:
                         LLMCompletionEndEvent(
                             prompt=str(args[0]),
                             response=f_return_val,
+                            span_id=span_id,
                         )
                     )
 
diff --git a/llama-index-core/llama_index/core/llms/llm.py b/llama-index-core/llama_index/core/llms/llm.py
index 678d9dfcea..b8349ea892 100644
--- a/llama-index-core/llama_index/core/llms/llm.py
+++ b/llama-index-core/llama_index/core/llms/llm.py
@@ -51,6 +51,9 @@ from llama_index.core.types import (
 )
 from llama_index.core.instrumentation.events.llm import (
     LLMPredictEndEvent,
+    LLMPredictStartEvent,
+    LLMStructuredPredictEndEvent,
+    LLMStructuredPredictStartEvent,
 )
 
 import llama_index.core.instrumentation as instrument
@@ -323,6 +326,9 @@ class LLM(BaseLLM):
         """
         from llama_index.core.program.utils import get_program_for_llm
 
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(LLMStructuredPredictStartEvent())
         program = get_program_for_llm(
             output_cls,
             prompt,
@@ -330,7 +336,9 @@ class LLM(BaseLLM):
             pydantic_program_mode=self.pydantic_program_mode,
         )
 
-        return program(**prompt_args)
+        result = program(**prompt_args)
+        dispatch_event(LLMStructuredPredictEndEvent())
+        return result
 
     @dispatcher.span
     async def astructured_predict(
@@ -369,6 +377,10 @@ class LLM(BaseLLM):
         """
         from llama_index.core.program.utils import get_program_for_llm
 
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(LLMStructuredPredictStartEvent())
+
         program = get_program_for_llm(
             output_cls,
             prompt,
@@ -376,7 +388,9 @@ class LLM(BaseLLM):
             pydantic_program_mode=self.pydantic_program_mode,
         )
 
-        return await program.acall(**prompt_args)
+        result = await program.acall(**prompt_args)
+        dispatch_event(LLMStructuredPredictEndEvent())
+        return result
 
     # -- Prompt Chaining --
 
@@ -406,6 +420,9 @@ class LLM(BaseLLM):
             print(output)
             ```
         """
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(LLMPredictStartEvent())
         self._log_template_data(prompt, **prompt_args)
 
         if self.metadata.is_chat_model:
@@ -416,8 +433,7 @@ class LLM(BaseLLM):
             formatted_prompt = self._get_prompt(prompt, **prompt_args)
             response = self.complete(formatted_prompt, formatted=True)
             output = response.text
-
-        dispatcher.event(LLMPredictEndEvent())
+        dispatch_event(LLMPredictEndEvent())
         return self._parse_output(output)
 
     @dispatcher.span
@@ -489,6 +505,9 @@ class LLM(BaseLLM):
             print(output)
             ```
         """
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(LLMPredictStartEvent())
         self._log_template_data(prompt, **prompt_args)
 
         if self.metadata.is_chat_model:
@@ -500,7 +519,7 @@ class LLM(BaseLLM):
             response = await self.acomplete(formatted_prompt, formatted=True)
             output = response.text
 
-        dispatcher.event(LLMPredictEndEvent())
+        dispatch_event(LLMPredictEndEvent())
         return self._parse_output(output)
 
     @dispatcher.span
diff --git a/llama-index-core/llama_index/core/response_synthesizers/base.py b/llama-index-core/llama_index/core/response_synthesizers/base.py
index ebedd93aa6..e29f8bcb51 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/base.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/base.py
@@ -201,21 +201,33 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
         additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
         **response_kwargs: Any,
     ) -> RESPONSE_TYPE:
-        dispatcher.event(SynthesizeStartEvent(query=query))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            SynthesizeStartEvent(
+                query=query,
+            )
+        )
 
         if len(nodes) == 0:
             if self._streaming:
                 empty_response = StreamingResponse(
                     response_gen=empty_response_generator()
                 )
-                dispatcher.event(
-                    SynthesizeEndEvent(query=query, response=empty_response)
+                dispatch_event(
+                    SynthesizeEndEvent(
+                        query=query,
+                        response=empty_response,
+                    )
                 )
                 return empty_response
             else:
                 empty_response = Response("Empty Response")
-                dispatcher.event(
-                    SynthesizeEndEvent(query=query, response=empty_response)
+                dispatch_event(
+                    SynthesizeEndEvent(
+                        query=query,
+                        response=empty_response,
+                    )
                 )
                 return empty_response
 
@@ -223,7 +235,8 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
             query = QueryBundle(query_str=query)
 
         with self._callback_manager.event(
-            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
+            CBEventType.SYNTHESIZE,
+            payload={EventPayload.QUERY_STR: query.query_str},
         ) as event:
             response_str = self.get_response(
                 query_str=query.query_str,
@@ -240,7 +253,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
 
             event.on_end(payload={EventPayload.RESPONSE: response})
 
-        dispatcher.event(SynthesizeEndEvent(query=query, response=response))
+        dispatch_event(
+            SynthesizeEndEvent(
+                query=query,
+                response=response,
+            )
+        )
         return response
 
     @dispatcher.span
@@ -251,20 +269,32 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
         additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
         **response_kwargs: Any,
     ) -> RESPONSE_TYPE:
-        dispatcher.event(SynthesizeStartEvent(query=query))
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(
+            SynthesizeStartEvent(
+                query=query,
+            )
+        )
         if len(nodes) == 0:
             if self._streaming:
                 empty_response = AsyncStreamingResponse(
                     response_gen=empty_response_agenerator()
                 )
-                dispatcher.event(
-                    SynthesizeEndEvent(query=query, response=empty_response)
+                dispatch_event(
+                    SynthesizeEndEvent(
+                        query=query,
+                        response=empty_response,
+                    )
                 )
                 return empty_response
             else:
                 empty_response = Response("Empty Response")
-                dispatcher.event(
-                    SynthesizeEndEvent(query=query, response=empty_response)
+                dispatch_event(
+                    SynthesizeEndEvent(
+                        query=query,
+                        response=empty_response,
+                    )
                 )
                 return empty_response
 
@@ -272,7 +302,8 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
             query = QueryBundle(query_str=query)
 
         with self._callback_manager.event(
-            CBEventType.SYNTHESIZE, payload={EventPayload.QUERY_STR: query.query_str}
+            CBEventType.SYNTHESIZE,
+            payload={EventPayload.QUERY_STR: query.query_str},
         ) as event:
             response_str = await self.aget_response(
                 query_str=query.query_str,
@@ -289,7 +320,12 @@ class BaseSynthesizer(ChainableMixin, PromptMixin):
 
             event.on_end(payload={EventPayload.RESPONSE: response})
 
-        dispatcher.event(SynthesizeEndEvent(query=query, response=response))
+        dispatch_event(
+            SynthesizeEndEvent(
+                query=query,
+                response=response,
+            )
+        )
         return response
 
     def _as_query_component(self, **kwargs: Any) -> QueryComponent:
diff --git a/llama-index-core/llama_index/core/response_synthesizers/refine.py b/llama-index-core/llama_index/core/response_synthesizers/refine.py
index 61bb4564bc..7b7f753266 100644
--- a/llama-index-core/llama_index/core/response_synthesizers/refine.py
+++ b/llama-index-core/llama_index/core/response_synthesizers/refine.py
@@ -172,7 +172,9 @@ class Refine(BaseSynthesizer):
         **response_kwargs: Any,
     ) -> RESPONSE_TEXT_TYPE:
         """Give response over chunks."""
-        dispatcher.event(GetResponseStartEvent())
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(GetResponseStartEvent())
         response: Optional[RESPONSE_TEXT_TYPE] = None
         for text_chunk in text_chunks:
             if prev_response is None:
@@ -194,7 +196,7 @@ class Refine(BaseSynthesizer):
                 response = response or "Empty Response"
         else:
             response = cast(Generator, response)
-        dispatcher.event(GetResponseEndEvent())
+        dispatch_event(GetResponseEndEvent())
         return response
 
     def _default_program_factory(self, prompt: PromptTemplate) -> BasePydanticProgram:
@@ -350,7 +352,9 @@ class Refine(BaseSynthesizer):
         prev_response: Optional[RESPONSE_TEXT_TYPE] = None,
         **response_kwargs: Any,
     ) -> RESPONSE_TEXT_TYPE:
-        dispatcher.event(GetResponseStartEvent())
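+        # bind the span_id before any awaits so both events land on this span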
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(GetResponseStartEvent())
         response: Optional[RESPONSE_TEXT_TYPE] = None
         for text_chunk in text_chunks:
             if prev_response is None:
@@ -373,6 +377,7 @@ class Refine(BaseSynthesizer):
                 response = response or "Empty Response"
         else:
             response = cast(AsyncGenerator, response)
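+        # fire the end event on the span_id captured at entry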
+        dispatch_event(GetResponseEndEvent())
         return response
 
     async def _arefine_response_single(
diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py
index c256719b4b..f7269f11b3 100644
--- a/llama-index-core/tests/instrumentation/test_dispatcher.py
+++ b/llama-index-core/tests/instrumentation/test_dispatcher.py
@@ -1,9 +1,13 @@
+import asyncio
 import inspect
 from asyncio import CancelledError
+from collections import Counter
+from typing import List
 
 import pytest
 import llama_index.core.instrumentation as instrument
 from llama_index.core.instrumentation.dispatcher import Dispatcher
+from llama_index.core.instrumentation.events import BaseEvent
+from llama_index.core.instrumentation.event_handlers import BaseEventHandler
 from unittest.mock import patch, MagicMock
 
 dispatcher = instrument.get_dispatcher("test")
@@ -12,43 +16,98 @@ value_error = ValueError("value error")
 cancelled_error = CancelledError("cancelled error")
 
 
+class _TestStartEvent(BaseEvent):
+    @classmethod
+    def class_name(cls):
+        return "_TestStartEvent"
+
+
+class _TestEndEvent(BaseEvent):
+    @classmethod
+    def class_name(cls):
+        return "_TestEndEvent"
+
+
+class _TestEventHandler(BaseEventHandler):
+    # annotated pydantic field: each handler instance gets its own list
+    events: List[BaseEvent] = []
+
+    @classmethod
+    def class_name(cls):
+        return "_TestEventHandler"
+
+    def handle(self, e: BaseEvent):
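+        # record every event so tests can assert on each event's span_id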
+        self.events.append(e)
+
+
 @dispatcher.span
-def func(*args, a, b=3, **kwargs):
+def func(a, b=3, **kwargs):
     return a + b
 
 
 @dispatcher.span
-async def async_func(*args, a, b=3, **kwargs):
+async def async_func(a, b=3, **kwargs):
     return a + b
 
 
 @dispatcher.span
-def func_exc(*args, a, b=3, c=4, **kwargs):
+def func_exc(a, b=3, c=4, **kwargs):
     raise value_error
 
 
 @dispatcher.span
-async def async_func_exc(*args, a, b=3, c=4, **kwargs):
+async def async_func_exc(a, b=3, c=4, **kwargs):
     raise cancelled_error
 
 
+@dispatcher.span
+def func_with_event(a, b=3, **kwargs):
+    dispatch_event = dispatcher.get_dispatch_event()
+
+    dispatch_event(_TestStartEvent())
+
+
+@dispatcher.span
+async def async_func_with_event(a, b=3, **kwargs):
+    dispatch_event = dispatcher.get_dispatch_event()
+
+    dispatch_event(_TestStartEvent())
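+    # yield control so concurrently running calls interleave here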
+    await asyncio.sleep(0.1)
+    dispatch_event(_TestEndEvent())
+
+
 class _TestObject:
     @dispatcher.span
-    def func(self, *args, a, b=3, **kwargs):
+    def func(self, a, b=3, **kwargs):
         return a + b
 
     @dispatcher.span
-    async def async_func(self, *args, a, b=3, **kwargs):
+    async def async_func(self, a, b=3, **kwargs):
         return a + b
 
     @dispatcher.span
-    def func_exc(self, *args, a, b=3, c=4, **kwargs):
+    def func_exc(self, a, b=3, c=4, **kwargs):
         raise value_error
 
     @dispatcher.span
-    async def async_func_exc(self, *args, a, b=3, c=4, **kwargs):
+    async def async_func_exc(self, a, b=3, c=4, **kwargs):
         raise cancelled_error
 
+    @dispatcher.span
+    def func_with_event(self, a, b=3, **kwargs):
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(_TestStartEvent())
+
+    @dispatcher.span
+    async def async_func_with_event(self, a, b=3, **kwargs):
+        dispatch_event = dispatcher.get_dispatch_event()
+
+        dispatch_event(_TestStartEvent())
+        await asyncio.sleep(0.1)
+        # the nested call opens a new span with its own span_id; that is
+        # fine because dispatch_event remains bound to this method's span_id
+        await self.async_func(1)
+        dispatch_event(_TestEndEvent())
+
 
 @patch.object(Dispatcher, "span_exit")
 @patch.object(Dispatcher, "span_enter")
@@ -58,12 +117,12 @@ def test_dispatcher_span_args(mock_uuid, mock_span_enter, mock_span_exit):
     mock_uuid.uuid4.return_value = "mock"
 
     # act
-    result = func(1, 2, a=3, c=5)
+    result = func(3, c=5)
 
     # assert
     # span_enter
     span_id = f"{func.__qualname__}-mock"
-    bound_args = inspect.signature(func).bind(1, 2, a=3, c=5)
+    bound_args = inspect.signature(func).bind(3, c=5)
     mock_span_enter.assert_called_once()
     args, kwargs = mock_span_enter.call_args
     assert args == ()
@@ -89,12 +148,12 @@ def test_dispatcher_span_args_with_instance(mock_uuid, mock_span_enter, mock_spa
 
     # act
     instance = _TestObject()
-    result = instance.func(1, 2, a=3, c=5)
+    result = instance.func(3, c=5)
 
     # assert
     # span_enter
     span_id = f"{instance.func.__qualname__}-mock"
-    bound_args = inspect.signature(instance.func).bind(1, 2, a=3, c=5)
+    bound_args = inspect.signature(instance.func).bind(3, c=5)
     mock_span_enter.assert_called_once()
     args, kwargs = mock_span_enter.call_args
     assert args == ()
@@ -126,7 +185,7 @@ def test_dispatcher_span_drop_args(
 
     with pytest.raises(ValueError):
         # act
-        _ = func_exc(7, a=3, b=5, c=2, d=5)
+        _ = func_exc(3, b=5, c=2, d=5)
 
     # assert
     # span_enter
@@ -135,7 +194,7 @@ def test_dispatcher_span_drop_args(
     # span_drop
     mock_span_drop.assert_called_once()
     span_id = f"{func_exc.__qualname__}-mock"
-    bound_args = inspect.signature(func_exc).bind(7, a=3, b=5, c=2, d=5)
+    bound_args = inspect.signature(func_exc).bind(3, b=5, c=2, d=5)
     args, kwargs = mock_span_drop.call_args
     assert args == ()
     assert kwargs == {
@@ -165,7 +224,7 @@ def test_dispatcher_span_drop_args(
     with pytest.raises(ValueError):
         # act
         instance = _TestObject()
-        _ = instance.func_exc(7, a=3, b=5, c=2, d=5)
+        _ = instance.func_exc(a=3, b=5, c=2, d=5)
 
     # assert
     # span_enter
@@ -174,7 +233,7 @@ def test_dispatcher_span_drop_args(
     # span_drop
     mock_span_drop.assert_called_once()
     span_id = f"{instance.func_exc.__qualname__}-mock"
-    bound_args = inspect.signature(instance.func_exc).bind(7, a=3, b=5, c=2, d=5)
+    bound_args = inspect.signature(instance.func_exc).bind(a=3, b=5, c=2, d=5)
     args, kwargs = mock_span_drop.call_args
     assert args == ()
     assert kwargs == {
@@ -197,12 +256,12 @@ async def test_dispatcher_async_span_args(mock_uuid, mock_span_enter, mock_span_
     mock_uuid.uuid4.return_value = "mock"
 
     # act
-    result = await async_func(1, 2, a=3, c=5)
+    result = await async_func(a=3, c=5)
 
     # assert
     # span_enter
     span_id = f"{async_func.__qualname__}-mock"
-    bound_args = inspect.signature(async_func).bind(1, 2, a=3, c=5)
+    bound_args = inspect.signature(async_func).bind(a=3, c=5)
     mock_span_enter.assert_called_once()
     args, kwargs = mock_span_enter.call_args
     assert args == ()
@@ -231,12 +290,12 @@ async def test_dispatcher_async_span_args_with_instance(
 
     # act
     instance = _TestObject()
-    result = await instance.async_func(1, 2, a=3, c=5)
+    result = await instance.async_func(a=3, c=5)
 
     # assert
     # span_enter
     span_id = f"{instance.async_func.__qualname__}-mock"
-    bound_args = inspect.signature(instance.async_func).bind(1, 2, a=3, c=5)
+    bound_args = inspect.signature(instance.async_func).bind(a=3, c=5)
     mock_span_enter.assert_called_once()
     args, kwargs = mock_span_enter.call_args
     assert args == ()
@@ -269,7 +328,7 @@ async def test_dispatcher_async_span_drop_args(
 
     with pytest.raises(CancelledError):
         # act
-        _ = await async_func_exc(7, a=3, b=5, c=2, d=5)
+        _ = await async_func_exc(a=3, b=5, c=2, d=5)
 
     # assert
     # span_enter
@@ -278,7 +337,7 @@ async def test_dispatcher_async_span_drop_args(
     # span_drop
     mock_span_drop.assert_called_once()
     span_id = f"{async_func_exc.__qualname__}-mock"
-    bound_args = inspect.signature(async_func_exc).bind(7, a=3, b=5, c=2, d=5)
+    bound_args = inspect.signature(async_func_exc).bind(a=3, b=5, c=2, d=5)
     args, kwargs = mock_span_drop.call_args
     assert args == ()
     assert kwargs == {
@@ -309,7 +368,7 @@ async def test_dispatcher_async_span_drop_args_with_instance(
     with pytest.raises(CancelledError):
         # act
         instance = _TestObject()
-        _ = await instance.async_func_exc(7, a=3, b=5, c=2, d=5)
+        _ = await instance.async_func_exc(a=3, b=5, c=2, d=5)
 
     # assert
     # span_enter
@@ -318,7 +377,7 @@ async def test_dispatcher_async_span_drop_args_with_instance(
     # span_drop
     mock_span_drop.assert_called_once()
     span_id = f"{instance.async_func_exc.__qualname__}-mock"
-    bound_args = inspect.signature(instance.async_func_exc).bind(7, a=3, b=5, c=2, d=5)
+    bound_args = inspect.signature(instance.async_func_exc).bind(a=3, b=5, c=2, d=5)
     args, kwargs = mock_span_drop.call_args
     assert args == ()
     assert kwargs == {
@@ -330,3 +389,138 @@ async def test_dispatcher_async_span_drop_args_with_instance(
 
     # span_exit
     mock_span_exit.assert_not_called()
+
+
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_drop")
+@patch.object(Dispatcher, "span_enter")
+@patch("llama_index.core.instrumentation.dispatcher.uuid")
+def test_dispatcher_fire_event(
+    mock_uuid: MagicMock,
+    mock_span_enter: MagicMock,
+    mock_span_drop: MagicMock,
+    mock_span_exit: MagicMock,
+):
+    # arrange
+    mock_uuid.uuid4.return_value = "mock"
+    event_handler = _TestEventHandler()
+    dispatcher.add_event_handler(event_handler)
+
+    # act
+    _ = func_with_event(3, c=5)
+
+    # assert
+    span_id = f"{func_with_event.__qualname__}-mock"
+    assert all(e.span_id == span_id for e in event_handler.events)
+
+    # span_enter
+    mock_span_enter.assert_called_once()
+
+    # span_drop
+    mock_span_drop.assert_not_called()
+
+    # span_exit
+    mock_span_exit.assert_called_once()
+
+
+@pytest.mark.asyncio()
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_drop")
+@patch.object(Dispatcher, "span_enter")
+async def test_dispatcher_async_fire_event(
+    mock_span_enter: MagicMock,
+    mock_span_drop: MagicMock,
+    mock_span_exit: MagicMock,
+):
+    # arrange
+    event_handler = _TestEventHandler()
+    dispatcher.add_event_handler(event_handler)
+
+    # act
+    tasks = [
+        async_func_with_event(a=3, c=5),
+        async_func_with_event(5),
+        async_func_with_event(4),
+    ]
+    _ = await asyncio.gather(*tasks)
+
+    # assert
+    span_ids = [e.span_id for e in event_handler.events]
+    id_counts = Counter(span_ids)
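+    # each coroutine fires a start and an end event sharing one span_id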
+    assert set(id_counts.values()) == {2}
+
+    # span_enter
+    assert mock_span_enter.call_count == 3
+
+    # span_drop
+    mock_span_drop.assert_not_called()
+
+    # span_exit
+    assert mock_span_exit.call_count == 3
+
+
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_drop")
+@patch.object(Dispatcher, "span_enter")
+@patch("llama_index.core.instrumentation.dispatcher.uuid")
+def test_dispatcher_fire_event_with_instance(
+    mock_uuid: MagicMock,
+    mock_span_enter: MagicMock,
+    mock_span_drop: MagicMock,
+    mock_span_exit: MagicMock,
+):
+    # arrange
+    mock_uuid.uuid4.return_value = "mock"
+    event_handler = _TestEventHandler()
+    dispatcher.add_event_handler(event_handler)
+
+    # act
+    instance = _TestObject()
+    _ = instance.func_with_event(a=3, c=5)
+
+    # assert
+    span_id = f"{instance.func_with_event.__qualname__}-mock"
+    assert all(e.span_id == span_id for e in event_handler.events)
+
+    # span_enter
+    mock_span_enter.assert_called_once()
+
+    # span_drop
+    mock_span_drop.assert_not_called()
+
+    # span_exit
+    mock_span_exit.assert_called_once()
+
+
+@pytest.mark.asyncio()
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_drop")
+@patch.object(Dispatcher, "span_enter")
+async def test_dispatcher_async_fire_event_with_instance(
+    mock_span_enter: MagicMock,
+    mock_span_drop: MagicMock,
+    mock_span_exit: MagicMock,
+):
+    # arrange
+    event_handler = _TestEventHandler()
+    dispatcher.add_event_handler(event_handler)
+
+    # act
+    instance = _TestObject()
+    tasks = [
+        instance.async_func_with_event(a=3, c=5),
+        instance.async_func_with_event(5),
+    ]
+    _ = await asyncio.gather(*tasks)
+
+    # assert
+    span_ids = [e.span_id for e in event_handler.events]
+    id_counts = Counter(span_ids)
+    assert set(id_counts.values()) == {2}
+
+    # span_enter (two outer spans plus two nested self.async_func spans)
+    assert mock_span_enter.call_count == 4
+
+    # span_drop
+    mock_span_drop.assert_not_called()
+
+    # span_exit
+    assert mock_span_exit.call_count == 4
diff --git a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD b/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD
deleted file mode 100644
index dabf212d7e..0000000000
--- a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/BUILD
+++ /dev/null
@@ -1 +0,0 @@
-python_tests()
diff --git a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/__init__.py b/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py b/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py
deleted file mode 100644
index c1d54abd7c..0000000000
--- a/llama-index-integrations/callbacks/llama-index-callbacks-aim/tests/test_aim_callback.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from llama_index.callbacks.aim.base import AimCallback
-from llama_index.core.callbacks.base_handler import BaseCallbackHandler
-
-
-def test_class():
-    names_of_base_classes = [b.__name__ for b in AimCallback.__mro__]
-    assert BaseCallbackHandler.__name__ in names_of_base_classes
-- 
GitLab