From c8c0ac9a89070b2732c398f09457c1f35c1ac348 Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Thu, 21 Mar 2024 17:27:32 -0400 Subject: [PATCH] Instrumentation enhancements (#12147) --- .../llama_index/core/base/base_retriever.py | 12 +- .../core/instrumentation/dispatcher.py | 24 +-- .../core/instrumentation/events/retrieval.py | 7 + .../instrumentation/span_handlers/base.py | 24 ++- .../instrumentation/span_handlers/null.py | 12 +- .../instrumentation/span_handlers/simple.py | 10 +- llama-index-core/tests/instrumentation/BUILD | 3 + .../tests/instrumentation/test_dispatcher.py | 156 ++++++++++++++++++ .../tests/instrumentation/test_manager.py | 13 ++ 9 files changed, 228 insertions(+), 33 deletions(-) create mode 100644 llama-index-core/tests/instrumentation/BUILD create mode 100644 llama-index-core/tests/instrumentation/test_dispatcher.py create mode 100644 llama-index-core/tests/instrumentation/test_manager.py diff --git a/llama-index-core/llama_index/core/base/base_retriever.py b/llama-index-core/llama_index/core/base/base_retriever.py index 51e56416e5..6dc7649de6 100644 --- a/llama-index-core/llama_index/core/base/base_retriever.py +++ b/llama-index-core/llama_index/core/base/base_retriever.py @@ -225,7 +225,7 @@ class BaseRetriever(ChainableMixin, PromptMixin): """ self._check_callback_manager() - dispatcher.event(RetrievalStartEvent()) + dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle)) if isinstance(str_or_query_bundle, str): query_bundle = QueryBundle(str_or_query_bundle) else: @@ -240,13 +240,15 @@ class BaseRetriever(ChainableMixin, PromptMixin): retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) - dispatcher.event(RetrievalEndEvent()) + dispatcher.event( + RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes) + ) return nodes @dispatcher.span async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: self._check_callback_manager() - dispatcher.event(RetrievalStartEvent()) + dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle)) if isinstance(str_or_query_bundle, str): query_bundle = QueryBundle(str_or_query_bundle) else: @@ -263,7 +265,9 @@ class BaseRetriever(ChainableMixin, PromptMixin): retrieve_event.on_end( payload={EventPayload.NODES: nodes}, ) - dispatcher.event(RetrievalEndEvent()) + dispatcher.event( + RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes) + ) return nodes @abstractmethod diff --git a/llama-index-core/llama_index/core/instrumentation/dispatcher.py b/llama-index-core/llama_index/core/instrumentation/dispatcher.py index c04b921f55..10d8cecb89 100644 --- a/llama-index-core/llama_index/core/instrumentation/dispatcher.py +++ b/llama-index-core/llama_index/core/instrumentation/dispatcher.py @@ -58,34 +58,34 @@ class Dispatcher(BaseModel): else: c = c.parent - def span_enter(self, id: str, **kwargs) -> None: + def span_enter(self, *args, id: str, **kwargs) -> None: """Send notice to handlers that a span with id has started.""" c = self while c: for h in c.span_handlers: - h.span_enter(id, **kwargs) + h.span_enter(*args, id=id, **kwargs) if not c.propagate: c = None else: c = c.parent - def span_drop(self, id: str, err: Optional[Exception], **kwargs) -> None: + def span_drop(self, *args, id: str, err: Optional[Exception], **kwargs) -> None: """Send notice to handlers that a span with id is being dropped.""" c = self while c: for h in c.span_handlers: - h.span_drop(id, err, **kwargs) + h.span_drop(*args, id=id, err=err, **kwargs) if not c.propagate: c = None else: c = c.parent - def span_exit(self, id: str, result: Optional[Any] = None, **kwargs) -> None: + def span_exit(self, *args, id: str, result: Optional[Any] = None, **kwargs) -> None: """Send notice to handlers that a span with id is exiting.""" c = self while c: for h in c.span_handlers: - h.span_exit(id, result, **kwargs) + h.span_exit(*args, id=id, result=result, **kwargs) if not c.propagate: c = None else: @@ -95,25 +95,25 @@ class Dispatcher(BaseModel): @functools.wraps(func) def wrapper(*args, **kwargs): id = f"{func.__qualname__}-{uuid.uuid4()}" - self.span_enter(id=id, **kwargs) + self.span_enter(*args, id=id, **kwargs) try: result = func(*args, **kwargs) except Exception as e: - self.span_drop(id=id, err=e) + self.span_drop(*args, id=id, err=e, **kwargs) else: - self.span_exit(id=id, result=result) + self.span_exit(*args, id=id, result=result, **kwargs) return result @functools.wraps(func) async def async_wrapper(*args, **kwargs): id = f"{func.__qualname__}-{uuid.uuid4()}" - self.span_enter(id=id, **kwargs) + self.span_enter(*args, id=id, **kwargs) try: result = await func(*args, **kwargs) except Exception as e: - self.span_drop(id=id, err=e) + self.span_drop(*args, id=id, err=e, **kwargs) else: - self.span_exit(id=id, result=result) + self.span_exit(*args, id=id, result=result, **kwargs) return result if inspect.iscoroutinefunction(func): diff --git a/llama-index-core/llama_index/core/instrumentation/events/retrieval.py b/llama-index-core/llama_index/core/instrumentation/events/retrieval.py index c62b80e7cc..010b092ab2 100644 --- a/llama-index-core/llama_index/core/instrumentation/events/retrieval.py +++ b/llama-index-core/llama_index/core/instrumentation/events/retrieval.py @@ -1,7 +1,11 @@ +from typing import List from llama_index.core.instrumentation.events.base import BaseEvent +from llama_index.core.schema import QueryType, NodeWithScore class RetrievalStartEvent(BaseEvent): + str_or_query_bundle: QueryType + @classmethod def class_name(cls): """Class name.""" @@ -9,6 +13,9 @@ class RetrievalStartEvent(BaseEvent): class RetrievalEndEvent(BaseEvent): + str_or_query_bundle: QueryType + nodes: List[NodeWithScore] + @classmethod def class_name(cls): """Class name.""" diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py index 396ef45af4..f18cf658ad 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/base.py @@ -22,44 +22,50 @@ class BaseSpanHandler(BaseModel, Generic[T]): """Class name.""" return "BaseSpanHandler" - def span_enter(self, id: str, **kwargs) -> None: + def span_enter(self, *args, id: str, **kwargs) -> None: """Logic for entering a span.""" if id in self.open_spans: pass # should probably raise an error here else: # TODO: thread safe? - span = self.new_span(id=id, parent_span_id=self.current_span_id, **kwargs) + span = self.new_span( + *args, id=id, parent_span_id=self.current_span_id, **kwargs + ) if span: self.open_spans[id] = span self.current_span_id = id - def span_exit(self, id: str, result: Optional[Any] = None, **kwargs) -> None: + def span_exit(self, *args, id: str, result: Optional[Any] = None, **kwargs) -> None: """Logic for exiting a span.""" - self.prepare_to_exit_span(id, result=result, **kwargs) + self.prepare_to_exit_span(*args, id=id, result=result, **kwargs) if self.current_span_id == id: self.current_span_id = self.open_spans[id].parent_id del self.open_spans[id] - def span_drop(self, id: str, err: Optional[Exception], **kwargs) -> None: + def span_drop(self, *args, id: str, err: Optional[Exception], **kwargs) -> None: """Logic for dropping a span i.e. early exit.""" - self.prepare_to_drop_span(id, err, **kwargs) + self.prepare_to_drop_span(*args, id=id, err=err, **kwargs) if self.current_span_id == id: self.current_span_id = self.open_spans[id].parent_id del self.open_spans[id] @abstractmethod - def new_span(self, id: str, parent_span_id: Optional[str], **kwargs) -> Optional[T]: + def new_span( + self, *args, id: str, parent_span_id: Optional[str], **kwargs + ) -> Optional[T]: """Create a span.""" ... @abstractmethod def prepare_to_exit_span( - self, id: str, result: Optional[Any] = None, **kwargs + self, *args, id: str, result: Optional[Any] = None, **kwargs ) -> Any: """Logic for preparing to exit a span.""" ... @abstractmethod - def prepare_to_drop_span(self, id: str, err: Optional[Exception], **kwargs) -> Any: + def prepare_to_drop_span( + self, *args, id: str, err: Optional[Exception], **kwargs + ) -> Any: """Logic for preparing to drop a span.""" ... diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py index 47456b051c..fccc12c42c 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/null.py @@ -9,25 +9,27 @@ class NullSpanHandler(BaseSpanHandler[BaseSpan]): """Class name.""" return "NullSpanHandler" - def span_enter(self, id: str, **kwargs) -> None: + def span_enter(self, *args, id: str, **kwargs) -> None: """Logic for entering a span.""" return - def span_exit(self, id: str, result: Optional[Any], **kwargs) -> None: + def span_exit(self, *args, id: str, result: Optional[Any], **kwargs) -> None: """Logic for exiting a span.""" return - def new_span(self, id: str, parent_span_id: Optional[str], **kwargs) -> None: + def new_span(self, *args, id: str, parent_span_id: Optional[str], **kwargs) -> None: """Create a span.""" return def prepare_to_exit_span( - self, id: str, result: Optional[Any] = None, **kwargs + self, *args, id: str, result: Optional[Any] = None, **kwargs ) -> None: """Logic for exiting a span.""" return - def prepare_to_drop_span(self, id: str, err: Optional[Exception], **kwargs) -> None: + def prepare_to_drop_span( + self, *args, id: str, err: Optional[Exception], **kwargs + ) -> None: """Logic for droppping a span.""" if err: raise err diff --git a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py index 1cea445b77..a1ceafd879 100644 --- a/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py +++ b/llama-index-core/llama_index/core/instrumentation/span_handlers/simple.py @@ -20,12 +20,14 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): """Class name.""" return "SimpleSpanHandler" - def new_span(self, id: str, parent_span_id: Optional[str], **kwargs) -> SimpleSpan: + def new_span( + self, *args, id: str, parent_span_id: Optional[str], **kwargs + ) -> SimpleSpan: """Create a span.""" return SimpleSpan(id_=id, parent_id=parent_span_id) def prepare_to_exit_span( - self, id: str, result: Optional[Any] = None, **kwargs + self, *args, id: str, result: Optional[Any] = None, **kwargs ) -> None: """Logic for preparing to drop a span.""" span = self.open_spans[id] @@ -34,7 +36,9 @@ class SimpleSpanHandler(BaseSpanHandler[SimpleSpan]): span.duration = (span.end_time - span.start_time).total_seconds() self.completed_spans += [span] - def prepare_to_drop_span(self, id: str, err: Optional[Exception], **kwargs) -> None: + def prepare_to_drop_span( + self, *args, id: str, err: Optional[Exception], **kwargs + ) -> None: """Logic for droppping a span.""" if err: raise err diff --git a/llama-index-core/tests/instrumentation/BUILD b/llama-index-core/tests/instrumentation/BUILD new file mode 100644 index 0000000000..57341b1358 --- /dev/null +++ b/llama-index-core/tests/instrumentation/BUILD @@ -0,0 +1,3 @@ +python_tests( + name="tests", +) diff --git a/llama-index-core/tests/instrumentation/test_dispatcher.py b/llama-index-core/tests/instrumentation/test_dispatcher.py new file mode 100644 index 0000000000..4805a959bb --- /dev/null +++ b/llama-index-core/tests/instrumentation/test_dispatcher.py @@ -0,0 +1,156 @@ +import pytest +import llama_index.core.instrumentation as instrument +from llama_index.core.instrumentation.dispatcher import Dispatcher +from unittest.mock import patch, MagicMock + +dispatcher = instrument.get_dispatcher("test") + + +@dispatcher.span +def func(*args, a, b=3, **kwargs): + return a + b + + +@dispatcher.span +async def async_func(*args, a, b=3, **kwargs): + return a + b + + +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_enter") +@patch("llama_index.core.instrumentation.dispatcher.uuid") +def test_dispatcher_span_args(mock_uuid, mock_span_enter, mock_span_exit): + # arrange + mock_uuid.uuid4.return_value = "mock" + + # act + result = func(1, 2, a=3, c=5) + + # assert + # span_enter + span_id = f"{func.__qualname__}-mock" + mock_span_enter.assert_called_once() + args, kwargs = mock_span_enter.call_args + assert args == (1, 2) + assert kwargs == {"id": span_id, "a": 3, "c": 5} + + # span_exit + args, kwargs = mock_span_exit.call_args + assert args == (1, 2) + assert kwargs == {"id": span_id, "a": 3, "c": 5, "result": result} + + +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_drop") +@patch.object(Dispatcher, "span_enter") +@patch("llama_index.core.instrumentation.dispatcher.uuid") +@patch(f"{__name__}.func") +def test_dispatcher_span_drop_args( + mock_func: MagicMock, + mock_uuid: MagicMock, + mock_span_enter: MagicMock, + mock_span_drop: MagicMock, + mock_span_exit: MagicMock, +): + # arrange + class CustomException(Exception): + pass + + mock_uuid.uuid4.return_value = "mock" + mock_func.side_effect = CustomException + + with pytest.raises(CustomException): + # act + result = func(7, a=3, b=5, c=2, d=5) + + # assert + # span_enter + mock_span_enter.assert_called_once() + + # span_drop + mock_span_drop.assert_called_once() + span_id = f"{func.__qualname__}-mock" + args, kwargs = mock_span_exit.call_args + assert args == (7,) + assert kwargs == { + "id": span_id, + "a": 3, + "b": 5, + "c": 2, + "d": 2, + "err": CustomException, + } + + # span_exit + mock_span_exit.assert_not_called() + + +@pytest.mark.asyncio() +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_enter") +@patch("llama_index.core.instrumentation.dispatcher.uuid") +async def test_dispatcher_async_span_args(mock_uuid, mock_span_enter, mock_span_exit): + # arrange + mock_uuid.uuid4.return_value = "mock" + + # act + result = await async_func(1, 2, a=3, c=5) + + # assert + # span_enter + span_id = f"{async_func.__qualname__}-mock" + mock_span_enter.assert_called_once() + args, kwargs = mock_span_enter.call_args + assert args == (1, 2) + assert kwargs == {"id": span_id, "a": 3, "c": 5} + + # span_exit + args, kwargs = mock_span_exit.call_args + assert args == (1, 2) + assert kwargs == {"id": span_id, "a": 3, "c": 5, "result": result} + + +@pytest.mark.asyncio() +@patch.object(Dispatcher, "span_exit") +@patch.object(Dispatcher, "span_drop") +@patch.object(Dispatcher, "span_enter") +@patch("llama_index.core.instrumentation.dispatcher.uuid") +@patch(f"{__name__}.async_func") +async def test_dispatcher_aysnc_span_drop_args( + mock_func: MagicMock, + mock_uuid: MagicMock, + mock_span_enter: MagicMock, + mock_span_drop: MagicMock, + mock_span_exit: MagicMock, +): + # arrange + class CustomException(Exception): + pass + + mock_uuid.uuid4.return_value = "mock" + mock_func.side_effect = CustomException + + with pytest.raises(CustomException): + # act + result = await async_func(7, a=3, b=5, c=2, d=5) + + # assert + # span_enter + mock_span_enter.assert_called_once() + + # span_drop + mock_span_drop.assert_called_once() + span_id = f"{func.__qualname__}-mock" + args, kwargs = mock_span_exit.call_args + assert args == (7,) + assert kwargs == { + "id": span_id, + "a": 3, + "b": 5, + "c": 2, + "d": 2, + "err": CustomException, + } + + # span_exit + mock_span_exit.assert_not_called() diff --git a/llama-index-core/tests/instrumentation/test_manager.py b/llama-index-core/tests/instrumentation/test_manager.py new file mode 100644 index 0000000000..0739dbc4ec --- /dev/null +++ b/llama-index-core/tests/instrumentation/test_manager.py @@ -0,0 +1,13 @@ +import llama_index.core.instrumentation as instrument + + +def test_root_manager_add_dispatcher(): + # arrange + root_manager = instrument.root_manager + + # act + dispatcher = instrument.get_dispatcher("test") + + # assert + assert "root" in root_manager.dispatchers + assert "test" in root_manager.dispatchers -- GitLab