Unverified Commit 18ef7c4d authored by Logan, committed by GitHub

Memory/Chat Store improvements (#12394)

parent e8f8fa5d
@@ -4,6 +4,14 @@ from typing import Any, AsyncGenerator, Generator, Optional, Union, List
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
 
+try:
+    from pydantic import BaseModel as V2BaseModel
+    from pydantic.v1 import BaseModel as V1BaseModel
+except ImportError:
+    from pydantic import BaseModel as V2BaseModel
+
+    V1BaseModel = V2BaseModel
+
 
 class MessageRole(str, Enum):
     """Message role."""
@@ -39,6 +47,32 @@ class ChatMessage(BaseModel):
             role = MessageRole(role)
         return cls(role=role, content=content, **kwargs)
 
+    def _recursive_serialization(self, value: Any) -> Any:
+        if isinstance(value, (V1BaseModel, V2BaseModel)):
+            return value.dict()
+        if isinstance(value, dict):
+            return {
+                key: self._recursive_serialization(value)
+                for key, value in value.items()
+            }
+        if isinstance(value, list):
+            return [self._recursive_serialization(item) for item in value]
+        return value
+
+    def dict(self, **kwargs: Any) -> dict:
+        # ensure all additional_kwargs are serializable
+        msg = super().dict(**kwargs)
+
+        for key, value in msg["additional_kwargs"].items():
+            value = self._recursive_serialization(value)
+            if not isinstance(value, (str, int, float, bool, dict, list, type(None))):
+                raise ValueError(
+                    f"Failed to serialize additional_kwargs value: {value}"
+                )
+            msg["additional_kwargs"][key] = value
+
+        return msg
+
 
 class LogProb(BaseModel):
     """LogProb of a token."""
...
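The ChatMessage.dict override above exists so that messages whose additional_kwargs carry pydantic objects (tool calls, function-call payloads, and the like) can be round-tripped through JSON-backed chat stores. A minimal sketch of the effect, assuming pydantic is installed and the usual llama_index.core.llms import path; the ToolCall model and its fields are made up for illustration:

from pydantic import BaseModel

from llama_index.core.llms import ChatMessage, MessageRole


class ToolCall(BaseModel):
    # hypothetical payload an LLM integration might attach to a message
    name: str
    arguments: dict


msg = ChatMessage(
    role=MessageRole.ASSISTANT,
    content="Calling a tool",
    additional_kwargs={"tool_call": ToolCall(name="search", arguments={"q": "llamas"})},
)

# _recursive_serialization converts the nested model into a plain dict,
# so the result is JSON-serializable and safe to hand to a chat store.
serialized = msg.dict()
assert isinstance(serialized["additional_kwargs"]["tool_call"], dict)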
@@ -109,20 +109,26 @@ class ChatMemoryBuffer(BaseMemory):
             raise ValueError("Initial token count exceeds token limit")
 
         message_count = len(chat_history)
-        token_count = (
-            self._token_count_for_message_count(message_count) + initial_token_count
-        )
+        cur_messages = chat_history[-message_count:]
+        token_count = self._token_count_for_messages(cur_messages) + initial_token_count
 
         while token_count > self.token_limit and message_count > 1:
             message_count -= 1
+            if chat_history[-message_count].role == MessageRole.TOOL:
+                # all tool messages should be preceded by an assistant message
+                # if we remove a tool message, we need to remove the assistant message too
+                message_count -= 1
+
             if chat_history[-message_count].role == MessageRole.ASSISTANT:
                 # we cannot have an assistant message at the start of the chat history
                 # if after removal of the first, we have an assistant message,
                 # we need to remove the assistant message too
                 message_count -= 1
+
+            cur_messages = chat_history[-message_count:]
             token_count = (
-                self._token_count_for_message_count(message_count) + initial_token_count
+                self._token_count_for_messages(cur_messages) + initial_token_count
             )
 
         # catch one message longer than token limit
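The new MessageRole.TOOL check keeps tool results paired with the assistant message that requested them when the buffer trims old messages to fit the token limit. A rough usage sketch, assuming the standard ChatMemoryBuffer and ChatMessage imports; the token limit and message contents are arbitrary:

from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.memory import ChatMemoryBuffer

memory = ChatMemoryBuffer.from_defaults(token_limit=50)
memory.set(
    [
        ChatMessage(role=MessageRole.USER, content="What is 2 + 2?"),
        ChatMessage(role=MessageRole.ASSISTANT, content="Let me check with the calculator tool."),
        ChatMessage(role=MessageRole.TOOL, content="4"),
        ChatMessage(role=MessageRole.ASSISTANT, content="2 + 2 = 4."),
    ]
)

# If get() has to drop old messages to stay under token_limit, the returned
# window is never cut so that it starts on an orphaned TOOL result (or on a
# leading ASSISTANT message); the offending message is dropped as well.
print(memory.get())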
@@ -137,6 +143,7 @@ class ChatMemoryBuffer(BaseMemory):
     def put(self, message: ChatMessage) -> None:
         """Put chat history."""
+        # ensure everything is serialized
         self.chat_store.add_message(self.chat_store_key, message)
 
     def set(self, messages: List[ChatMessage]) -> None:
@@ -147,10 +154,9 @@ class ChatMemoryBuffer(BaseMemory):
         """Reset chat history."""
         self.chat_store.delete_messages(self.chat_store_key)
 
-    def _token_count_for_message_count(self, message_count: int) -> int:
-        if message_count <= 0:
+    def _token_count_for_messages(self, messages: List[ChatMessage]) -> int:
+        if len(messages) <= 0:
             return 0
-        chat_history = self.get_all()
-        msg_str = " ".join(str(m.content) for m in chat_history[-message_count:])
+        msg_str = " ".join(str(m.content) for m in messages)
         return len(self.tokenizer_fn(msg_str))
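The renamed _token_count_for_messages receives the candidate message window directly, which drops the extra get_all() read against the chat store on every trim iteration. Token counting itself is unchanged: tokenizer_fn is any callable from string to token list, and the buffer only needs its length. A small configuration sketch, assuming tiktoken is available and the tokenizer_fn keyword on from_defaults; any tokenizer callable works:

import tiktoken

from llama_index.core.memory import ChatMemoryBuffer

enc = tiktoken.get_encoding("cl100k_base")
memory = ChatMemoryBuffer.from_defaults(
    token_limit=3000,
    tokenizer_fn=enc.encode,  # str -> list of token ids
)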