From 070104d0bd81e2cc2b91d7d3e7d490488cc7a195 Mon Sep 17 00:00:00 2001
From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com>
Date: Tue, 26 Mar 2024 18:18:53 -0400
Subject: [PATCH] Small fixes to instrumentation (#12287)

span_handlers in root dispatcher; return Optional[T] to handle null handler case
---
 .../core/instrumentation/__init__.py          |  2 +-
 .../instrumentation/span_handlers/base.py     | 22 ++++++++++---------
 .../instrumentation/span_handlers/simple.py   | 10 +++++----
 3 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/llama-index-core/llama_index/core/instrumentation/__init__.py b/llama-index-core/llama_index/core/instrumentation/__init__.py
index a060bf2f09..b7564a02f7 100644
--- a/llama-index-core/llama_index/core/instrumentation/__init__.py
+++ b/llama-index-core/llama_index/core/instrumentation/__init__.py
@@ -5,7 +5,7 @@ from llama_index.core.instrumentation.span_handlers import NullSpanHandler
 root_dispatcher: Dispatcher = Dispatcher(
     name="root",
     event_handlers=[NullEventHandler()],
-    span_handler=NullSpanHandler(),
+    span_handlers=[NullSpanHandler()],
     propagate=False,
 )
 
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
index f18cf658ad..a54a8d7b30 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
@@ -37,17 +37,19 @@ class BaseSpanHandler(BaseModel, Generic[T]):
 
     def span_exit(self, *args, id: str, result: Optional[Any] = None, **kwargs) -> None:
         """Logic for exiting a span."""
-        self.prepare_to_exit_span(*args, id=id, result=result, **kwargs)
-        if self.current_span_id == id:
-            self.current_span_id = self.open_spans[id].parent_id
-        del self.open_spans[id]
+        span = self.prepare_to_exit_span(*args, id=id, result=result, **kwargs)
+        if span:
+            if self.current_span_id == id:
+                self.current_span_id = self.open_spans[id].parent_id
+            del self.open_spans[id]
 
     def span_drop(self, *args, id: str, err: Optional[Exception], **kwargs) -> None:
         """Logic for dropping a span i.e. early exit."""
-        self.prepare_to_drop_span(*args, id=id, err=err, **kwargs)
-        if self.current_span_id == id:
-            self.current_span_id = self.open_spans[id].parent_id
-        del self.open_spans[id]
+        span = self.prepare_to_drop_span(*args, id=id, err=err, **kwargs)
+        if span:
+            if self.current_span_id == id:
+                self.current_span_id = self.open_spans[id].parent_id
+            del self.open_spans[id]
 
     @abstractmethod
     def new_span(
@@ -59,13 +61,13 @@ class BaseSpanHandler(BaseModel, Generic[T]):
     @abstractmethod
     def prepare_to_exit_span(
         self, *args, id: str, result: Optional[Any] = None, **kwargs
-    ) -> Any:
+    ) -> Optional[T]:
         """Logic for preparing to exit a span."""
         ...
 
     @abstractmethod
     def prepare_to_drop_span(
         self, *args, id: str, err: Optional[Exception], **kwargs
-    ) -> Any:
+    ) -> Optional[T]:
         """Logic for preparing to drop a span."""
         ...
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
index a1ceafd879..96c54b2fb2 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
@@ -28,20 +28,22 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
 
     def prepare_to_exit_span(
         self, *args, id: str, result: Optional[Any] = None, **kwargs
-    ) -> None:
+    ) -> SimpleSpan:
         """Logic for preparing to drop a span."""
         span = self.open_spans[id]
         span = cast(SimpleSpan, span)
         span.end_time = datetime.now()
         span.duration = (span.end_time - span.start_time).total_seconds()
         self.completed_spans += [span]
+        return span
 
     def prepare_to_drop_span(
         self, *args, id: str, err: Optional[Exception], **kwargs
-    ) -> None:
+    ) -> SimpleSpan:
         """Logic for droppping a span."""
-        if err:
-            raise err
+        if id in self.open_spans:
+            return self.open_spans[id]
+        return None
 
     def _get_trace_trees(self) -> List["Tree"]:
         """Method for getting trace trees."""
-- 
GitLab