From c8c0ac9a89070b2732c398f09457c1f35c1ac348 Mon Sep 17 00:00:00 2001
From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com>
Date: Thu, 21 Mar 2024 17:27:32 -0400
Subject: [PATCH] Instrumentation enhancements (#12147)

---
 .../llama_index/core/base/base_retriever.py   |  12 +-
 .../core/instrumentation/dispatcher.py        |  24 +--
 .../core/instrumentation/events/retrieval.py  |   7 +
 .../instrumentation/span_handlers/base.py     |  24 ++-
 .../instrumentation/span_handlers/null.py     |  12 +-
 .../instrumentation/span_handlers/simple.py   |  10 +-
 llama-index-core/tests/instrumentation/BUILD  |   3 +
 .../tests/instrumentation/test_dispatcher.py  | 156 ++++++++++++++++++
 .../tests/instrumentation/test_manager.py     |  13 ++
 9 files changed, 228 insertions(+), 33 deletions(-)
 create mode 100644 llama-index-core/tests/instrumentation/BUILD
 create mode 100644 llama-index-core/tests/instrumentation/test_dispatcher.py
 create mode 100644 llama-index-core/tests/instrumentation/test_manager.py

diff --git a/llama-index-core/llama_index/core/base/base_retriever.py b/llama-index-core/llama_index/core/base/base_retriever.py
index 51e56416e5..6dc7649de6 100644
--- a/llama-index-core/llama_index/core/base/base_retriever.py
+++ b/llama-index-core/llama_index/core/base/base_retriever.py
@@ -225,7 +225,7 @@ class BaseRetriever(ChainableMixin, PromptMixin):
 
         """
         self._check_callback_manager()
-        dispatcher.event(RetrievalStartEvent())
+        dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle))
         if isinstance(str_or_query_bundle, str):
             query_bundle = QueryBundle(str_or_query_bundle)
         else:
@@ -240,13 +240,15 @@ class BaseRetriever(ChainableMixin, PromptMixin):
                 retrieve_event.on_end(
                     payload={EventPayload.NODES: nodes},
                 )
-        dispatcher.event(RetrievalEndEvent())
+        dispatcher.event(
+            RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes)
+        )
         return nodes
 
     @dispatcher.span
     async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
         self._check_callback_manager()
-        dispatcher.event(RetrievalStartEvent())
+        dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle))
         if isinstance(str_or_query_bundle, str):
             query_bundle = QueryBundle(str_or_query_bundle)
         else:
@@ -263,7 +265,9 @@ class BaseRetriever(ChainableMixin, PromptMixin):
                 retrieve_event.on_end(
                     payload={EventPayload.NODES: nodes},
                 )
-        dispatcher.event(RetrievalEndEvent())
+        dispatcher.event(
+            RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes)
+        )
         return nodes
 
     @abstractmethod
diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
index c04b921f55..10d8cecb89 100644
--- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py
+++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py
@@ -58,34 +58,34 @@ class Dispatcher(BaseModel):
             else:
                 c = c.parent
 
-    def span_enter(self, id: str, **kwargs) -> None:
+    def span_enter(self, *args, id: str, **kwargs) -> None:
         """Send notice to handlers that a span with id has started."""
         c = self
         while c:
             for h in c.span_handlers:
-                h.span_enter(id, **kwargs)
+                h.span_enter(*args, id=id, **kwargs)
             if not c.propagate:
                 c = None
             else:
                 c = c.parent
 
-    def span_drop(self, id: str, err: Optional[Exception], **kwargs) -> None:
+    def span_drop(self, *args, id: str, err: Optional[Exception], **kwargs) -> None:
         """Send notice to handlers that a span with id is being dropped."""
         c = self
         while c:
             for h in c.span_handlers:
-                h.span_drop(id, err, **kwargs)
+                h.span_drop(*args, id=id, err=err, **kwargs)
             if not c.propagate:
                 c = None
             else:
                 c = c.parent
 
-    def span_exit(self, id: str, result: Optional[Any] = None, **kwargs) -> None:
+    def span_exit(self, *args, id: str, result: Optional[Any] = None, **kwargs) -> None:
         """Send notice to handlers that a span with id is exiting."""
         c = self
         while c:
             for h in c.span_handlers:
-                h.span_exit(id, result, **kwargs)
+                h.span_exit(*args, id=id, result=result, **kwargs)
             if not c.propagate:
                 c = None
             else:
@@ -95,25 +95,25 @@ class Dispatcher(BaseModel):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
             id = f"{func.__qualname__}-{uuid.uuid4()}"
-            self.span_enter(id=id, **kwargs)
+            self.span_enter(*args, id=id, **kwargs)
             try:
                 result = func(*args, **kwargs)
             except Exception as e:
-                self.span_drop(id=id, err=e)
+                self.span_drop(*args, id=id, err=e, **kwargs)
             else:
-                self.span_exit(id=id, result=result)
+                self.span_exit(*args, id=id, result=result, **kwargs)
                 return result
 
         @functools.wraps(func)
         async def async_wrapper(*args, **kwargs):
             id = f"{func.__qualname__}-{uuid.uuid4()}"
-            self.span_enter(id=id, **kwargs)
+            self.span_enter(*args, id=id, **kwargs)
             try:
                 result = await func(*args, **kwargs)
             except Exception as e:
-                self.span_drop(id=id, err=e)
+                self.span_drop(*args, id=id, err=e, **kwargs)
             else:
-                self.span_exit(id=id, result=result)
+                self.span_exit(*args, id=id, result=result, **kwargs)
                 return result
 
         if inspect.iscoroutinefunction(func):
diff --git a/llama-index-core/llama_index/core/instrumentation/events/retrieval.py b/llama-index-core/llama_index/core/instrumentation/events/retrieval.py
index c62b80e7cc..010b092ab2 100644
--- a/llama-index-core/llama_index/core/instrumentation/events/retrieval.py
+++ b/llama-index-core/llama_index/core/instrumentation/events/retrieval.py
@@ -1,7 +1,11 @@
+from typing import List
 from llama_index.core.instrumentation.events.base import BaseEvent
+from llama_index.core.schema import QueryType, NodeWithScore
 
 
 class RetrievalStartEvent(BaseEvent):
+    str_or_query_bundle: QueryType
+
     @classmethod
     def class_name(cls):
         """Class name."""
@@ -9,6 +13,9 @@ class RetrievalStartEvent(BaseEvent):
 
 
 class RetrievalEndEvent(BaseEvent):
+    str_or_query_bundle: QueryType
+    nodes: List[NodeWithScore]
+
     @classmethod
     def class_name(cls):
         """Class name."""
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
index 396ef45af4..f18cf658ad 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py
@@ -22,44 +22,50 @@ class BaseSpanHandler(BaseModel, Generic[T]):
         """Class name."""
         return "BaseSpanHandler"
 
-    def span_enter(self, id: str, **kwargs) -> None:
+    def span_enter(self, *args, id: str, **kwargs) -> None:
         """Logic for entering a span."""
         if id in self.open_spans:
             pass  # should probably raise an error here
         else:
             # TODO: thread safe?
-            span = self.new_span(id=id, parent_span_id=self.current_span_id, **kwargs)
+            span = self.new_span(
+                *args, id=id, parent_span_id=self.current_span_id, **kwargs
+            )
             if span:
                 self.open_spans[id] = span
                 self.current_span_id = id
 
-    def span_exit(self, id: str, result: Optional[Any] = None, **kwargs) -> None:
+    def span_exit(self, *args, id: str, result: Optional[Any] = None, **kwargs) -> None:
         """Logic for exiting a span."""
-        self.prepare_to_exit_span(id, result=result, **kwargs)
+        self.prepare_to_exit_span(*args, id=id, result=result, **kwargs)
         if self.current_span_id == id:
             self.current_span_id = self.open_spans[id].parent_id
         del self.open_spans[id]
 
-    def span_drop(self, id: str, err: Optional[Exception], **kwargs) -> None:
+    def span_drop(self, *args, id: str, err: Optional[Exception], **kwargs) -> None:
         """Logic for dropping a span i.e. early exit."""
-        self.prepare_to_drop_span(id, err, **kwargs)
+        self.prepare_to_drop_span(*args, id=id, err=err, **kwargs)
         if self.current_span_id == id:
             self.current_span_id = self.open_spans[id].parent_id
         del self.open_spans[id]
 
     @abstractmethod
-    def new_span(self, id: str, parent_span_id: Optional[str], **kwargs) -> Optional[T]:
+    def new_span(
+        self, *args, id: str, parent_span_id: Optional[str], **kwargs
+    ) -> Optional[T]:
         """Create a span."""
         ...
 
     @abstractmethod
     def prepare_to_exit_span(
-        self, id: str, result: Optional[Any] = None, **kwargs
+        self, *args, id: str, result: Optional[Any] = None, **kwargs
     ) -> Any:
         """Logic for preparing to exit a span."""
         ...
 
     @abstractmethod
-    def prepare_to_drop_span(self, id: str, err: Optional[Exception], **kwargs) -> Any:
+    def prepare_to_drop_span(
+        self, *args, id: str, err: Optional[Exception], **kwargs
+    ) -> Any:
         """Logic for preparing to drop a span."""
         ...
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py
index 47456b051c..fccc12c42c 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py
@@ -9,25 +9,27 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]):
         """Class name."""
         return "NullSpanHandler"
 
-    def span_enter(self, id: str, **kwargs) -> None:
+    def span_enter(self, *args, id: str, **kwargs) -> None:
         """Logic for entering a span."""
         return
 
-    def span_exit(self, id: str, result: Optional[Any], **kwargs) -> None:
+    def span_exit(self, *args, id: str, result: Optional[Any], **kwargs) -> None:
         """Logic for exiting a span."""
         return
 
-    def new_span(self, id: str, parent_span_id: Optional[str], **kwargs) -> None:
+    def new_span(self, *args, id: str, parent_span_id: Optional[str], **kwargs) -> None:
         """Create a span."""
         return
 
     def prepare_to_exit_span(
-        self, id: str, result: Optional[Any] = None, **kwargs
+        self, *args, id: str, result: Optional[Any] = None, **kwargs
     ) -> None:
         """Logic for exiting a span."""
         return
 
-    def prepare_to_drop_span(self, id: str, err: Optional[Exception], **kwargs) -> None:
+    def prepare_to_drop_span(
+        self, *args, id: str, err: Optional[Exception], **kwargs
+    ) -> None:
         """Logic for droppping a span."""
         if err:
             raise err
diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
index 1cea445b77..a1ceafd879 100644
--- a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
+++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py
@@ -20,12 +20,14 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
         """Class name."""
         return "SimpleSpanHandler"
 
-    def new_span(self, id: str, parent_span_id: Optional[str], **kwargs) -> SimpleSpan:
+    def new_span(
+        self, *args, id: str, parent_span_id: Optional[str], **kwargs
+    ) -> SimpleSpan:
         """Create a span."""
         return SimpleSpan(id_=id, parent_id=parent_span_id)
 
     def prepare_to_exit_span(
-        self, id: str, result: Optional[Any] = None, **kwargs
+        self, *args, id: str, result: Optional[Any] = None, **kwargs
     ) -> None:
         """Logic for preparing to drop a span."""
         span = self.open_spans[id]
@@ -34,7 +36,9 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]):
         span.duration = (span.end_time - span.start_time).total_seconds()
         self.completed_spans += [span]
 
-    def prepare_to_drop_span(self, id: str, err: Optional[Exception], **kwargs) -> None:
+    def prepare_to_drop_span(
+        self, *args, id: str, err: Optional[Exception], **kwargs
+    ) -> None:
         """Logic for droppping a span."""
         if err:
             raise err
diff --git a/llama-index-core/tests/instrumentation/BUILD b/llama-index-core/tests/instrumentation/BUILD
new file mode 100644
index 0000000000..57341b1358
--- /dev/null
+++ b/llama-index-core/tests/instrumentation/BUILD
@@ -0,0 +1,3 @@
+python_tests(
+    name="tests",
+)
diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py
new file mode 100644
index 0000000000..4805a959bb
--- /dev/null
+++ b/llama-index-core/tests/instrumentation/test_dispatcher.py
@@ -0,0 +1,156 @@
+import pytest
+import llama_index.core.instrumentation as instrument
+from llama_index.core.instrumentation.dispatcher import Dispatcher
+from unittest.mock import patch, MagicMock
+
+dispatcher = instrument.get_dispatcher("test")
+
+
+@dispatcher.span
+def func(*args, a, b=3, **kwargs):
+    return a + b
+
+
+@dispatcher.span
+async def async_func(*args, a, b=3, **kwargs):
+    return a + b
+
+
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_enter")
+@patch("llama_index.core.instrumentation.dispatcher.uuid")
+def test_dispatcher_span_args(mock_uuid, mock_span_enter, mock_span_exit):
+    """Call arguments of a spanned function reach span_enter and span_exit."""
+    # Pin the uuid so the span id built by the wrapper is predictable.
+    mock_uuid.uuid4.return_value = "mock"
+    expected_id = f"{func.__qualname__}-mock"
+    call_kwargs = {"a": 3, "c": 5}
+
+    result = func(1, 2, **call_kwargs)
+
+    # span_enter receives the positional args, the span id and the keyword args.
+    mock_span_enter.assert_called_once()
+    enter_args, enter_kwargs = mock_span_enter.call_args
+    assert enter_args == (1, 2)
+    assert enter_kwargs == {"id": expected_id, **call_kwargs}
+
+    # span_exit receives the same arguments plus the wrapped function's result.
+    exit_args, exit_kwargs = mock_span_exit.call_args
+    assert exit_args == (1, 2)
+    assert exit_kwargs == {"id": expected_id, **call_kwargs, "result": result}
+
+
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_drop")
+@patch.object(Dispatcher, "span_enter")
+@patch("llama_index.core.instrumentation.dispatcher.uuid")
+def test_dispatcher_span_drop_args(
+    mock_uuid: MagicMock,
+    mock_span_enter: MagicMock,
+    mock_span_drop: MagicMock,
+    mock_span_exit: MagicMock,
+):
+    """A failing spanned call is reported through span_drop, not span_exit."""
+
+    # arrange
+    class CustomException(Exception):
+        pass
+
+    # Decorate a function that really raises: patching ``func`` itself would
+    # replace the span wrapper with a mock and bypass the dispatcher entirely.
+    @dispatcher.span
+    def func_exc(*args, a, b=3, **kwargs):
+        raise CustomException
+
+    # The span handlers in this change re-raise the error handed to them.
+    def reraise(*args, err, **kwargs):
+        raise err
+
+    mock_uuid.uuid4.return_value = "mock"
+    mock_span_drop.side_effect = reraise
+
+    # act
+    with pytest.raises(CustomException):
+        func_exc(7, a=3, b=5, c=2, d=5)
+
+    # assert (outside the ``raises`` block, otherwise nothing below would run)
+    mock_span_enter.assert_called_once()
+    mock_span_exit.assert_not_called()
+    mock_span_drop.assert_called_once()
+    args, kwargs = mock_span_drop.call_args
+    err = kwargs["err"]
+    assert isinstance(err, CustomException)  # an instance, not the class
+    assert args == (7,)
+    span_id = f"{func_exc.__qualname__}-mock"
+    assert kwargs == {"id": span_id, "a": 3, "b": 5, "c": 2, "d": 5, "err": err}
+
+
+@pytest.mark.asyncio()
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_enter")
+@patch("llama_index.core.instrumentation.dispatcher.uuid")
+async def test_dispatcher_async_span_args(mock_uuid, mock_span_enter, mock_span_exit):
+    """Call arguments of a spanned coroutine reach span_enter and span_exit."""
+    # Pin the uuid so the span id built by the wrapper is predictable.
+    mock_uuid.uuid4.return_value = "mock"
+    expected_id = f"{async_func.__qualname__}-mock"
+    call_kwargs = {"a": 3, "c": 5}
+
+    result = await async_func(1, 2, **call_kwargs)
+
+    # span_enter receives the positional args, the span id and the keyword args.
+    mock_span_enter.assert_called_once()
+    enter_args, enter_kwargs = mock_span_enter.call_args
+    assert enter_args == (1, 2)
+    assert enter_kwargs == {"id": expected_id, **call_kwargs}
+
+    # span_exit receives the same arguments plus the awaited result.
+    exit_args, exit_kwargs = mock_span_exit.call_args
+    assert exit_args == (1, 2)
+    assert exit_kwargs == {"id": expected_id, **call_kwargs, "result": result}
+
+
+@pytest.mark.asyncio()
+@patch.object(Dispatcher, "span_exit")
+@patch.object(Dispatcher, "span_drop")
+@patch.object(Dispatcher, "span_enter")
+@patch("llama_index.core.instrumentation.dispatcher.uuid")
+async def test_dispatcher_aysnc_span_drop_args(
+    mock_uuid: MagicMock,
+    mock_span_enter: MagicMock,
+    mock_span_drop: MagicMock,
+    mock_span_exit: MagicMock,
+):
+    """A failing spanned coroutine is reported through span_drop, not span_exit."""
+
+    # arrange
+    class CustomException(Exception):
+        pass
+
+    # Decorate a coroutine that really raises: patching ``async_func`` itself
+    # would replace the span wrapper with a mock and bypass the dispatcher.
+    @dispatcher.span
+    async def async_func_exc(*args, a, b=3, **kwargs):
+        raise CustomException
+
+    # The span handlers in this change re-raise the error handed to them.
+    def reraise(*args, err, **kwargs):
+        raise err
+
+    mock_uuid.uuid4.return_value = "mock"
+    mock_span_drop.side_effect = reraise
+
+    # act
+    with pytest.raises(CustomException):
+        await async_func_exc(7, a=3, b=5, c=2, d=5)
+
+    # assert (outside the ``raises`` block, otherwise nothing below would run)
+    mock_span_enter.assert_called_once()
+    mock_span_exit.assert_not_called()
+    mock_span_drop.assert_called_once()
+    args, kwargs = mock_span_drop.call_args
+    err = kwargs["err"]
+    assert isinstance(err, CustomException)  # an instance, not the class
+    assert args == (7,)
+    span_id = f"{async_func_exc.__qualname__}-mock"
+    assert kwargs == {"id": span_id, "a": 3, "b": 5, "c": 2, "d": 5, "err": err}
diff --git a/llama-index-core/tests/instrumentation/test_manager.py b/llama-index-core/tests/instrumentation/test_manager.py
new file mode 100644
index 0000000000..0739dbc4ec
--- /dev/null
+++ b/llama-index-core/tests/instrumentation/test_manager.py
@@ -0,0 +1,13 @@
+import llama_index.core.instrumentation as instrument
+
+
+def test_root_manager_add_dispatcher():
+    """The root manager lists "root" and a dispatcher requested by name."""
+    root_manager = instrument.root_manager
+
+    # act: the return value is not needed by the assertions below
+    instrument.get_dispatcher("test")
+
+    # assert
+    assert "root" in root_manager.dispatchers
+    assert "test" in root_manager.dispatchers
-- 
GitLab