From e6c34b97c4628b7f79ca57877f5db79d50cafcf1 Mon Sep 17 00:00:00 2001
From: Roger Yang <80478925+RogerHYang@users.noreply.github.com>
Date: Tue, 21 May 2024 05:32:11 -0700
Subject: [PATCH] feat(instrumentation): new spans for llms (#13565)

---
 .../llama_index/core/base/llms/base.py        | 21 ++++++++++++
 .../llama_index/core/llms/callbacks.py        | 32 +++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/llama-index-core/llama_index/core/base/llms/base.py b/llama-index-core/llama_index/core/base/llms/base.py
index fa7c816980..aa4946fda5 100644
--- a/llama-index-core/llama_index/core/base/llms/base.py
+++ b/llama-index-core/llama_index/core/base/llms/base.py
@@ -4,6 +4,7 @@ from typing import (
     Sequence,
 )
 
+from llama_index.core import instrumentation
 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
@@ -253,3 +254,23 @@ class BaseLLM(ChainableMixin, BaseComponent):
                 print(response.text, end="", flush=True)
             ```
         """
+
+    def __init_subclass__(cls, **kwargs) -> None:
+        """
+        Decorate each subclass's implementations of the abstract LLM methods.
+        `__init_subclass__` runs per subclass definition, much as `__init__` runs per instance: classes are themselves objects.
+        """
+        super().__init_subclass__(**kwargs)
+        dispatcher = instrumentation.get_dispatcher(cls.__module__)
+        for attr in (
+            "chat",
+            "complete",
+            "stream_chat",
+            "stream_complete",
+            "achat",
+            "acomplete",
+            "astream_chat",
+            "astream_complete",
+        ):
+            if callable(method := cls.__dict__.get(attr)):
+                setattr(cls, attr, dispatcher.span(method))
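
For reference, a minimal, self-contained sketch of the decoration pattern this
hunk relies on. The `traced` decorator, `Base`, and `MyLLM` are hypothetical
stand-ins for `dispatcher.span` and the real LLM classes:

    import functools
    from typing import Any, Callable

    def traced(fn: Callable[..., Any]) -> Callable[..., Any]:
        """Stand-in for dispatcher.span: log entry and exit around the call."""
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            print(f"enter {fn.__qualname__}")
            try:
                return fn(*args, **kwargs)
            finally:
                print(f"exit {fn.__qualname__}")
        return wrapper

    class Base:
        def __init_subclass__(cls, **kwargs: Any) -> None:
            super().__init_subclass__(**kwargs)
            for attr in ("chat", "complete"):
                # Read cls.__dict__ rather than getattr so only methods the
                # subclass itself defines are wrapped; an inherited,
                # already-wrapped method is never wrapped a second time.
                if callable(method := cls.__dict__.get(attr)):
                    setattr(cls, attr, traced(method))

    class MyLLM(Base):
        def chat(self, msg: str) -> str:
            return f"echo: {msg}"

    print(MyLLM().chat("hi"))  # enter/exit lines print around "echo: hi"

Because `__init_subclass__` fires when the `class MyLLM(Base)` statement
executes, every subclass is instrumented without any decorator at the
definition site.
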
diff --git a/llama-index-core/llama_index/core/llms/callbacks.py b/llama-index-core/llama_index/core/llms/callbacks.py
index 0e9ad8b837..44d7368f28 100644
--- a/llama-index-core/llama_index/core/llms/callbacks.py
+++ b/llama-index-core/llama_index/core/llms/callbacks.py
@@ -209,6 +209,22 @@ def llm_chat_callback() -> Callable:
         if not is_wrapped:
             f.__wrapped__ = True  # type: ignore
 
+        # Update the wrapper functions to look like the wrapped function.
+        # See e.g. https://github.com/python/cpython/blob/0abf997e75bd3a8b76d920d33cc64d5e6c2d380f/Lib/functools.py#L57
+        for attr in (
+            "__module__",
+            "__name__",
+            "__qualname__",
+            "__doc__",
+            "__annotations__",
+            "__type_params__",
+        ):
+            if v := getattr(f, attr, None):
+                setattr(async_dummy_wrapper, attr, v)
+                setattr(wrapped_async_llm_chat, attr, v)
+                setattr(dummy_wrapper, attr, v)
+                setattr(wrapped_llm_chat, attr, v)
+
         if asyncio.iscoroutinefunction(f):
             if is_wrapped:
                 return async_dummy_wrapper
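
The attribute-copying loop added above (and repeated in the next hunk) mirrors
what `functools.update_wrapper` does via `functools.WRAPPER_ASSIGNMENTS`, a
tuple that gained `__type_params__` in Python 3.12. A minimal sketch of the
equivalence, using only the standard library:

    import functools

    def wrapped(x: int) -> int:
        """Original docstring."""
        return x + 1

    def wrapper(*args, **kwargs):
        return wrapped(*args, **kwargs)

    for attr in functools.WRAPPER_ASSIGNMENTS:
        # Same guard as the patch: skip attributes the wrapped function
        # lacks or leaves empty, e.g. __type_params__ before Python 3.12.
        if v := getattr(wrapped, attr, None):
            setattr(wrapper, attr, v)

    print(wrapper.__name__, wrapper.__doc__)  # wrapped Original docstring.

The loop is written out by hand here, presumably because one function's
metadata must be mirrored onto four candidate wrappers at once, whereas
`functools.update_wrapper` targets a single wrapper per call.
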
@@ -394,6 +410,22 @@ def llm_completion_callback() -> Callable:
         if not is_wrapped:
             f.__wrapped__ = True  # type: ignore
 
+        # Update the wrapper functions to look like the wrapped function.
+        # See e.g. https://github.com/python/cpython/blob/0abf997e75bd3a8b76d920d33cc64d5e6c2d380f/Lib/functools.py#L57
+        for attr in (
+            "__module__",
+            "__name__",
+            "__qualname__",
+            "__doc__",
+            "__annotations__",
+            "__type_params__",
+        ):
+            if v := getattr(f, attr, None):
+                setattr(async_dummy_wrapper, attr, v)
+                setattr(wrapped_async_llm_predict, attr, v)
+                setattr(dummy_wrapper, attr, v)
+                setattr(wrapped_llm_predict, attr, v)
+
         if asyncio.iscoroutinefunction(f):
             if is_wrapped:
                 return async_dummy_wrapper
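
With these spans in place, any span handler attached to the dispatcher will
observe chat/complete calls. A usage sketch, assuming the instrumentation
module's span-handler API as documented around this release (exact names may
vary between versions):

    import llama_index.core.instrumentation as instrument
    from llama_index.core.instrumentation.span_handlers import SimpleSpanHandler

    span_handler = SimpleSpanHandler()
    instrument.get_dispatcher().add_span_handler(span_handler)

    # ... call chat()/complete() on any BaseLLM subclass here ...

    span_handler.print_trace_trees()  # dump the spans collected above
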
-- 