Skip to content
Snippets Groups Projects
Unverified Commit 2c63e789 authored by Logan's avatar Logan Committed by GitHub
Browse files

patch nested traces of the same type (#6791)

parent 84a9a4a6
Branches
Tags
No related merge requests found
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
### Bug Fixes / Nits ### Bug Fixes / Nits
- fixed passing in query bundle to node postprocessors (#6780) - fixed passing in query bundle to node postprocessors (#6780)
- fixed error in callback manager with nested traces (#6791)
## [v0.7.3] - 2023-07-07 ## [v0.7.3] - 2023-07-07
......
...@@ -4,7 +4,7 @@ from collections import defaultdict ...@@ -4,7 +4,7 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Generator from typing import Any, Dict, List, Optional, Generator
from llama_index.callbacks.schema import CBEventType, LEAF_EVENTS, BASE_TRACE_ID from llama_index.callbacks.schema import CBEventType, LEAF_EVENTS, BASE_TRACE_EVENT
class BaseCallbackHandler(ABC): class BaseCallbackHandler(ABC):
...@@ -76,8 +76,8 @@ class CallbackManager(BaseCallbackHandler, ABC): ...@@ -76,8 +76,8 @@ class CallbackManager(BaseCallbackHandler, ABC):
"""Initialize the manager with a list of handlers.""" """Initialize the manager with a list of handlers."""
self.handlers = handlers self.handlers = handlers
self._trace_map: Dict[str, List[str]] = defaultdict(list) self._trace_map: Dict[str, List[str]] = defaultdict(list)
self._trace_stack: List[str] = [BASE_TRACE_ID] self._trace_event_stack: List[str] = [BASE_TRACE_EVENT]
self._trace_id: Optional[str] = None self._trace_id_stack: List[str] = []
def on_event_start( def on_event_start(
self, self,
...@@ -88,13 +88,13 @@ class CallbackManager(BaseCallbackHandler, ABC): ...@@ -88,13 +88,13 @@ class CallbackManager(BaseCallbackHandler, ABC):
) -> str: ) -> str:
"""Run handlers when an event starts and return id of event.""" """Run handlers when an event starts and return id of event."""
event_id = event_id or str(uuid.uuid4()) event_id = event_id or str(uuid.uuid4())
self._trace_map[self._trace_stack[-1]].append(event_id) self._trace_map[self._trace_event_stack[-1]].append(event_id)
for handler in self.handlers: for handler in self.handlers:
if event_type not in handler.event_starts_to_ignore: if event_type not in handler.event_starts_to_ignore:
handler.on_event_start(event_type, payload, event_id=event_id, **kwargs) handler.on_event_start(event_type, payload, event_id=event_id, **kwargs)
if event_type not in LEAF_EVENTS: if event_type not in LEAF_EVENTS:
self._trace_stack.append(event_id) self._trace_event_stack.append(event_id)
return event_id return event_id
...@@ -112,7 +112,7 @@ class CallbackManager(BaseCallbackHandler, ABC): ...@@ -112,7 +112,7 @@ class CallbackManager(BaseCallbackHandler, ABC):
handler.on_event_end(event_type, payload, event_id=event_id, **kwargs) handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)
if event_type not in LEAF_EVENTS: if event_type not in LEAF_EVENTS:
self._trace_stack.pop() self._trace_event_stack.pop()
def add_handler(self, handler: BaseCallbackHandler) -> None: def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager.""" """Add a handler to the callback manager."""
...@@ -135,13 +135,16 @@ class CallbackManager(BaseCallbackHandler, ABC): ...@@ -135,13 +135,16 @@ class CallbackManager(BaseCallbackHandler, ABC):
def start_trace(self, trace_id: Optional[str] = None) -> None: def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched.""" """Run when an overall trace is launched."""
if not self._trace_id: if trace_id is not None:
self._reset_trace_events() if len(self._trace_id_stack) == 0:
self._reset_trace_events()
for handler in self.handlers: for handler in self.handlers:
handler.start_trace(trace_id=trace_id) handler.start_trace(trace_id=trace_id)
self._trace_id = trace_id self._trace_id_stack = [trace_id]
else:
self._trace_id_stack.append(trace_id)
def end_trace( def end_trace(
self, self,
...@@ -149,15 +152,17 @@ class CallbackManager(BaseCallbackHandler, ABC): ...@@ -149,15 +152,17 @@ class CallbackManager(BaseCallbackHandler, ABC):
trace_map: Optional[Dict[str, List[str]]] = None, trace_map: Optional[Dict[str, List[str]]] = None,
) -> None: ) -> None:
"""Run when an overall trace is exited.""" """Run when an overall trace is exited."""
if trace_id is not None and trace_id == self._trace_id: if trace_id is not None and len(self._trace_id_stack) > 0:
for handler in self.handlers: self._trace_id_stack.pop()
handler.end_trace(trace_id=trace_id, trace_map=self._trace_map) if len(self._trace_id_stack) == 0:
self._trace_id = None for handler in self.handlers:
handler.end_trace(trace_id=trace_id, trace_map=self._trace_map)
self._trace_id_stack = []
def _reset_trace_events(self) -> None: def _reset_trace_events(self) -> None:
"""Helper function to reset the current trace.""" """Helper function to reset the current trace."""
self._trace_map = defaultdict(list) self._trace_map = defaultdict(list)
self._trace_stack = [BASE_TRACE_ID] self._trace_event_stack = [BASE_TRACE_EVENT]
@property @property
def trace_map(self) -> Dict[str, List[str]]: def trace_map(self) -> Dict[str, List[str]]:
......
...@@ -8,7 +8,7 @@ from llama_index.callbacks.schema import ( ...@@ -8,7 +8,7 @@ from llama_index.callbacks.schema import (
CBEventType, CBEventType,
EventStats, EventStats,
TIMESTAMP_FORMAT, TIMESTAMP_FORMAT,
BASE_TRACE_ID, BASE_TRACE_EVENT,
) )
...@@ -189,7 +189,7 @@ class LlamaDebugHandler(BaseCallbackHandler): ...@@ -189,7 +189,7 @@ class LlamaDebugHandler(BaseCallbackHandler):
"""Print simple trace map to terminal for debugging of the most recent trace.""" """Print simple trace map to terminal for debugging of the most recent trace."""
print("*" * 10, flush=True) print("*" * 10, flush=True)
print(f"Trace: {self._cur_trace_id}", flush=True) print(f"Trace: {self._cur_trace_id}", flush=True)
self._print_trace_map(BASE_TRACE_ID, level=1) self._print_trace_map(BASE_TRACE_EVENT, level=1)
print("*" * 10, flush=True) print("*" * 10, flush=True)
@property @property
......
...@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional ...@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional
TIMESTAMP_FORMAT = "%m/%d/%Y, %H:%M:%S.%f" TIMESTAMP_FORMAT = "%m/%d/%Y, %H:%M:%S.%f"
# base trace_id for the tracemap in callback_manager # base trace_id for the tracemap in callback_manager
BASE_TRACE_ID = "root" BASE_TRACE_EVENT = "root"
class CBEventType(str, Enum): class CBEventType(str, Enum):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment