From a9ff814fe55a832230820ee70032918e94ad15bb Mon Sep 17 00:00:00 2001 From: Sherif Abdekarim <sherif.abdelkarim91@gmail.com> Date: Wed, 20 Mar 2024 20:06:13 +0200 Subject: [PATCH] Create lazy initialization for async elements in StreamingAgentChatResponse (#12116) --- .../llama_index/core/chat_engine/types.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py index b92252870d..a4a0e09030 100644 --- a/llama-index-core/llama_index/core/chat_engine/types.py +++ b/llama-index-core/llama_index/core/chat_engine/types.py @@ -72,17 +72,17 @@ class StreamingAgentChatResponse: source_nodes: List[NodeWithScore] = field(default_factory=list) _unformatted_response: str = "" _queue: queue.Queue = field(default_factory=queue.Queue) - _aqueue: asyncio.Queue = field(default_factory=asyncio.Queue) + _aqueue: Optional[asyncio.Queue] = None # flag when chat message is a function call _is_function: Optional[bool] = None # flag when processing done _is_done = False # signal when a new item is added to the queue - _new_item_event: asyncio.Event = field(default_factory=asyncio.Event) + _new_item_event: Optional[asyncio.Event] = None # NOTE: async code uses two events rather than one since it yields # control when waiting for queue item # signal when the OpenAI functions stop executing - _is_function_false_event: asyncio.Event = field(default_factory=asyncio.Event) + _is_function_false_event: Optional[asyncio.Event] = None # signal when an OpenAI function is being executed _is_function_not_none_thread_event: Event = field(default_factory=Event) @@ -100,11 +100,20 @@ class StreamingAgentChatResponse: self.response = self._unformatted_response.strip() return self.response + def _ensure_async_setup(self) -> None: + if self._aqueue is None: + self._aqueue = asyncio.Queue() + if self._new_item_event is None: + self._new_item_event = asyncio.Event() + if self._is_function_false_event is None: + self._is_function_false_event = asyncio.Event() + def put_in_queue(self, delta: Optional[str]) -> None: self._queue.put_nowait(delta) self._is_function_not_none_thread_event.set() def aput_in_queue(self, delta: Optional[str]) -> None: + self._ensure_async_setup() self._aqueue.put_nowait(delta) self._new_item_event.set() @@ -207,6 +216,7 @@ class StreamingAgentChatResponse: self.response = self._unformatted_response.strip() async def async_response_gen(self) -> AsyncGenerator[str, None]: + self._ensure_async_setup() while True: if not self._aqueue.empty() or not self._is_done: try: -- GitLab