diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index 3ee9adf085084a5b5cb619eafc75c894557b81d5..567bd8a36e0906c4cdf4eba9b725a8975b008dc9 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -4,6 +4,14 @@ from typing import Any, AsyncGenerator, Generator, Optional, Union, List
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
 
+try:
+    from pydantic import BaseModel as V2BaseModel
+    from pydantic.v1 import BaseModel as V1BaseModel
+except ImportError:
+    from pydantic import BaseModel as V2BaseModel
+
+    V1BaseModel = V2BaseModel
+
 
 class MessageRole(str, Enum):
     """Message role."""
@@ -39,6 +47,32 @@ class ChatMessage(BaseModel):
             role = MessageRole(role)
         return cls(role=role, content=content, **kwargs)
 
+    def _recursive_serialization(self, value: Any) -> Any:
+        if isinstance(value, (V1BaseModel, V2BaseModel)):
+            return value.dict()
+        if isinstance(value, dict):
+            return {
+                key: self._recursive_serialization(value)
+                for key, value in value.items()
+            }
+        if isinstance(value, list):
+            return [self._recursive_serialization(item) for item in value]
+        return value
+
+    def dict(self, **kwargs: Any) -> dict:
+        # ensure all additional_kwargs are serializable
+        msg = super().dict(**kwargs)
+
+        for key, value in msg["additional_kwargs"].items():
+            value = self._recursive_serialization(value)
+            if not isinstance(value, (str, int, float, bool, dict, list, type(None))):
+                raise ValueError(
+                    f"Failed to serialize additional_kwargs value: {value}"
+                )
+            msg["additional_kwargs"][key] = value
+
+        return msg
+
 
 class LogProb(BaseModel):
     """LogProb of a token."""
diff --git a/llama-index-core/llama_index/core/memory/chat_memory_buffer.py b/llama-index-core/llama_index/core/memory/chat_memory_buffer.py
index 5c64af664d239c6007290a5c4b37ae2ed2bc45c2..2415f30e6be81349fc4433a797c0e94699fe4a28 100644
--- a/llama-index-core/llama_index/core/memory/chat_memory_buffer.py
+++ b/llama-index-core/llama_index/core/memory/chat_memory_buffer.py
@@ -109,20 +109,26 @@ class ChatMemoryBuffer(BaseMemory):
             raise ValueError("Initial token count exceeds token limit")
 
         message_count = len(chat_history)
-        token_count = (
-            self._token_count_for_message_count(message_count) + initial_token_count
-        )
+
+        cur_messages = chat_history[-message_count:]
+        token_count = self._token_count_for_messages(cur_messages) + initial_token_count
 
         while token_count > self.token_limit and message_count > 1:
             message_count -= 1
+            if chat_history[-message_count].role == MessageRole.TOOL:
+                # all tool messages should be preceded by an assistant message
+                # if we remove a tool message, we need to remove the assistant message too
+                message_count -= 1
+
             if chat_history[-message_count].role == MessageRole.ASSISTANT:
                 # we cannot have an assistant message at the start of the chat history
                 # if after removal of the first, we have an assistant message,
                 # we need to remove the assistant message too
                 message_count -= 1
 
+            cur_messages = chat_history[-message_count:]
             token_count = (
-                self._token_count_for_message_count(message_count) + initial_token_count
+                self._token_count_for_messages(cur_messages) + initial_token_count
             )
 
         # catch one message longer than token limit
@@ -137,6 +143,7 @@
 
     def put(self, message: ChatMessage) -> None:
         """Put chat history."""
+        # ensure everything is serialized
         self.chat_store.add_message(self.chat_store_key, message)
 
     def set(self, messages: List[ChatMessage]) -> None:
@@ -147,10 +154,9 @@
         """Reset chat history."""
         self.chat_store.delete_messages(self.chat_store_key)
 
-    def _token_count_for_message_count(self, message_count: int) -> int:
-        if message_count <= 0:
+    def _token_count_for_messages(self, messages: List[ChatMessage]) -> int:
+        if len(messages) <= 0:
             return 0
 
-        chat_history = self.get_all()
-        msg_str = " ".join(str(m.content) for m in chat_history[-message_count:])
+        msg_str = " ".join(str(m.content) for m in messages)
        return len(self.tokenizer_fn(msg_str))
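
For context, a minimal sketch of what the `ChatMessage.dict()` override buys: nested pydantic models stored in `additional_kwargs` (e.g. the pydantic v2 tool-call objects an LLM client may attach) are now flattened into plain dicts before the message reaches a chat store. The `ToolCall` model below is a hypothetical stand-in, not a llama-index or client type:

```python
from pydantic import BaseModel

from llama_index.core.base.llms.types import ChatMessage, MessageRole


class ToolCall(BaseModel):
    # hypothetical stand-in for a pydantic v2 object stored by an LLM client
    name: str
    arguments: dict


msg = ChatMessage(
    role=MessageRole.ASSISTANT,
    content="",
    additional_kwargs={
        "tool_calls": [ToolCall(name="search", arguments={"query": "llamas"})]
    },
)

# _recursive_serialization walks the list, recognizes the ToolCall as a
# v1/v2 BaseModel, and calls .dict() on it, so the result is plain
# JSON-serializable data:
print(msg.dict()["additional_kwargs"])
# {'tool_calls': [{'name': 'search', 'arguments': {'query': 'llamas'}}]}
```

Values that still aren't one of the plain JSON types after this pass now raise a `ValueError` instead of silently producing an unserializable message.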
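And a hedged sketch of the truncation behavior the `chat_memory_buffer.py` change enforces (not a test from this patch; the messages and the whitespace "tokenizer" are made up): when trimming to the token limit, the buffer now steps past TOOL results and dangling ASSISTANT messages rather than letting them become the first message of the returned history.

```python
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.memory import ChatMemoryBuffer

memory = ChatMemoryBuffer.from_defaults(
    token_limit=8,  # arbitrarily small, to force truncation
    tokenizer_fn=lambda text: text.split(),  # naive whitespace "tokenizer"
)
memory.set(
    [
        ChatMessage(role=MessageRole.USER, content="look up llamas please"),
        ChatMessage(role=MessageRole.ASSISTANT, content="calling the search tools"),
        ChatMessage(role=MessageRole.TOOL, content="llamas are camelids"),
        ChatMessage(role=MessageRole.TOOL, content="alpacas are too"),
        ChatMessage(role=MessageRole.ASSISTANT, content="Llamas and alpacas are camelids."),
        ChatMessage(role=MessageRole.USER, content="thanks!"),
    ]
)

# The full history is 20 "tokens", so get() must trim. Walking backwards,
# the new TOOL/ASSISTANT checks cascade past the tool results and their
# surrounding assistant messages; only the final USER message survives,
# and the returned history never starts with a TOOL message.
for m in memory.get():
    print(m.role, "|", m.content)  # MessageRole.USER | thanks!
```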