diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/.gitignore b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..990c18de229088f55c6c514fd0f2d49981d1b0e7 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/BUILD b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0896ca890d8bffd60a44fa824f8d57fecd73ee53 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/LICENSE b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e6074042d0391061b7d7d45d21e6693ae108daec --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/LICENSE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) Jerry Liu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/Makefile b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..b9eab05aa370629a4a3de75df3ff64cd53887b68 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. + sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/README.md b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc20c79a38bb2081d46b8b83e98a5e4a20c1eba7 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/README.md @@ -0,0 +1,69 @@ +# LlamaIndex Chat Store Integration: MongoDB Chat Store + +## Installation + +```bash +pip install llama-index-storage-chat-store-mongodb +``` + +## Usage + +Using `MongoChatStore` from `llama_index.storage.chat_store.mongo` +you can store chat history in MongoDB. + +```python +from llama_index.storage.chat_store.mongo import MongoChatStore + +# Initialize the MongoDB chat store with URI and database name and collection name +chat_store = MongoChatStore( + mongodb_uri="mongodb://localhost:27017/", + db_name="llama_index", + collection_name="chat_sessions", +) +``` + +You can also initialize the chat store with a `MongoClient` or `AsyncIOMotorClient` and a database name and collection name. + +```python +from pymongo import MongoClient +from motor.motor_asyncio import AsyncIOMotorClient + +client = MongoClient("mongodb://localhost:27017/") +async_client = AsyncIOMotorClient("mongodb://localhost:27017/") + +chat_store = MongoChatStore( + client=client, + amongo_client=async_client, + db_name="llama_index", + collection_name="chat_sessions", +) +``` + +You can also initialize the chat store with a `Collection` or `AsyncIOMotorCollection`. + +```python +from pymongo import Collection +from motor.motor_asyncio import AsyncIOMotorCollection + +client = MongoClient("mongodb://localhost:27017/") +async_client = AsyncIOMotorClient("mongodb://localhost:27017/") + +collection = client["llama_index"]["chat_sessions"] +async_collection = async_client["llama_index"]["chat_sessions"] + +chat_store = MongoChatStore( + collection=collection, async_collection=async_collection +) +``` + +## Usage with LlamaIndex + +```python +from llama_index.core.chat_engine.types import ChatMessage + +chat_memory = ChatMemoryBuffer.from_defaults( + token_limit=3000, + chat_store=chat_store, + chat_store_key="user1", +) +``` diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/BUILD b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..db46e8d6c978c67e301dd6c47bee08c1b3fd141c --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/__init__.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89a50699c7554be799afbbee6032473a1b9f0da8 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/__init__.py @@ -0,0 +1,3 @@ +from llama_index.storage.chat_store.mongo.base import MongoChatStore + +__all__ = ["MongoChatStore"] diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/base.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e4ce77c884cc01b109d05d94ea36c919e04bba --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/llama_index/storage/chat_store/mongo/base.py @@ -0,0 +1,351 @@ +from typing import Any, List, Optional +from datetime import datetime + +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore +from pymongo import MongoClient +from pymongo.collection import Collection +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection + + +def _message_to_dict(message: ChatMessage) -> dict: + """Convert a ChatMessage to a dictionary for MongoDB storage.""" + return message.model_dump() + + +def _dict_to_message(d: dict) -> ChatMessage: + """Convert a dictionary from MongoDB to a ChatMessage.""" + return ChatMessage.model_validate(d) + + +class MongoChatStore(BaseChatStore): + """MongoDB chat store implementation.""" + + mongo_uri: str = Field( + default="mongodb://localhost:27017", description="MongoDB URI." + ) + db_name: str = Field(default="default", description="MongoDB database name.") + collection_name: str = Field( + default="sessions", description="MongoDB collection name." + ) + ttl_seconds: Optional[int] = Field( + default=None, description="Time to live in seconds." + ) + _mongo_client: Optional[MongoClient] = PrivateAttr() + _async_client: Optional[AsyncIOMotorClient] = PrivateAttr() + + def __init__( + self, + mongo_uri: str = "mongodb://localhost:27017", + db_name: str = "default", + collection_name: str = "sessions", + mongo_client: Optional[MongoClient] = None, + amongo_client: Optional[AsyncIOMotorClient] = None, + ttl_seconds: Optional[int] = None, + collection: Optional[Collection] = None, + async_collection: Optional[AsyncIOMotorCollection] = None, + **kwargs: Any, + ) -> None: + """ + Initialize the MongoDB chat store. + + Args: + mongo_uri: MongoDB connection URI + db_name: Database name + collection_name: Collection name for storing chat messages + mongo_client: Optional pre-configured MongoDB client + amongo_client: Optional pre-configured async MongoDB client + ttl_seconds: Optional time-to-live for messages in seconds + **kwargs: Additional arguments to pass to MongoDB client + """ + super().__init__(ttl=ttl_seconds) + + self._mongo_client = mongo_client or MongoClient(mongo_uri, **kwargs) + self._async_client = amongo_client or AsyncIOMotorClient(mongo_uri, **kwargs) + + if collection: + self._collection = collection + else: + self._collection = self._mongo_client[db_name][collection_name] + + if async_collection: + self._async_collection = async_collection + else: + self._async_collection = self._async_client[db_name][collection_name] + + # Create TTL index if ttl is specified + if ttl_seconds: + self._collection.create_index("created_at", expireAfterSeconds=ttl_seconds) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "MongoChatStore" + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + """ + Set messages for a key. + + Args: + key: Key to set messages for + messages: List of ChatMessage objects + """ + # Delete existing messages for this key + self._collection.delete_many({"session_id": key}) + + # Insert new messages + if messages: + current_time = datetime.now() + message_dicts = [ + { + "session_id": key, + "index": i, + "message": _message_to_dict(msg), + "created_at": current_time, + } + for i, msg in enumerate(messages) + ] + self._collection.insert_many(message_dicts) + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """ + Set messages for a key asynchronously. + + Args: + key: Key to set messages for + messages: List of ChatMessage objects + """ + # Delete existing messages for this key + await self._async_collection.delete_many({"session_id": key}) + + # Insert new messages + if messages: + current_time = datetime.now() + message_dicts = [ + { + "session_id": key, + "index": i, + "message": _message_to_dict(msg), + "created_at": current_time, + } + for i, msg in enumerate(messages) + ] + await self._async_collection.insert_many(message_dicts) + + def get_messages(self, key: str) -> List[ChatMessage]: + """ + Get messages for a key. + + Args: + key: Key to get messages for + """ + # Find all messages for this key, sorted by index + docs = list(self._collection.find({"session_id": key}, sort=[("index", 1)])) + + # Convert to ChatMessage objects + return [_dict_to_message(doc["message"]) for doc in docs] + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """ + Get messages for a key asynchronously. + + Args: + key: Key to get messages for + """ + # Find all messages for this key, sorted by index + cursor = self._async_collection.find({"session_id": key}).sort("index", 1) + + # Convert to list and then to ChatMessage objects + docs = await cursor.to_list(length=None) + return [_dict_to_message(doc["message"]) for doc in docs] + + def add_message( + self, key: str, message: ChatMessage, idx: Optional[int] = None + ) -> None: + """ + Add a message for a key. + + Args: + key: Key to add message for + message: ChatMessage object to add + """ + if idx is None: + # Get the current highest index + highest_idx_doc = self._collection.find_one( + {"session_id": key}, sort=[("index", -1)] + ) + idx = 0 if highest_idx_doc is None else highest_idx_doc["index"] + 1 + + # Insert the new message with current timestamp + self._collection.insert_one( + { + "session_id": key, + "index": idx, + "message": _message_to_dict(message), + "created_at": datetime.now(), + } + ) + + async def async_add_message( + self, key: str, message: ChatMessage, idx: Optional[int] = None + ) -> None: + """ + Add a message for a key asynchronously. + + Args: + key: Key to add message for + message: ChatMessage object to add + """ + if idx is None: + # Get the current highest index + highest_idx_doc = await self._async_collection.find_one( + {"session_id": key}, sort=[("index", -1)] + ) + idx = 0 if highest_idx_doc is None else highest_idx_doc["index"] + 1 + + # Insert the new message with current timestamp + await self._async_collection.insert_one( + { + "session_id": key, + "index": idx, + "message": _message_to_dict(message), + "created_at": datetime.now(), + } + ) + + def delete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """ + Delete messages for a key. + + Args: + key: Key to delete messages for + """ + # Get messages before deleting + messages = self.get_messages(key) + + # Delete all messages for this key + self._collection.delete_many({"session_id": key}) + + return messages + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """ + Delete messages for a key asynchronously. + + Args: + key: Key to delete messages for + """ + # Get messages before deleting + messages = await self.aget_messages(key) + + # Delete all messages for this key + await self._async_collection.delete_many({"session_id": key}) + + return messages + + def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """ + Delete specific message for a key. + + Args: + key: Key to delete message for + idx: Index of message to delete + """ + # Find the message to delete + doc = self._collection.find_one({"session_id": key, "index": idx}) + if doc is None: + return None + + # Delete the message + self._collection.delete_one({"session_id": key, "index": idx}) + + # Reindex remaining messages + self._collection.update_many( + {"session_id": key, "index": {"$gt": idx}}, {"$inc": {"index": -1}} + ) + + return _dict_to_message(doc["message"]) + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """ + Delete specific message for a key asynchronously. + + Args: + key: Key to delete message for + idx: Index of message to delete + """ + # Find the message to delete + doc = await self._async_collection.find_one({"session_id": key, "index": idx}) + if doc is None: + return None + + # Delete the message + await self._async_collection.delete_one({"session_id": key, "index": idx}) + + # Reindex remaining messages + await self._async_collection.update_many( + {"session_id": key, "index": {"$gt": idx}}, {"$inc": {"index": -1}} + ) + + return _dict_to_message(doc["message"]) + + def delete_last_message(self, key: str) -> Optional[ChatMessage]: + """ + Delete last message for a key. + + Args: + key: Key to delete last message for + """ + # Find the last message + last_msg_doc = self._collection.find_one( + {"session_id": key}, sort=[("index", -1)] + ) + + if last_msg_doc is None: + return None + + # Delete the last message + self._collection.delete_one({"_id": last_msg_doc["_id"]}) + + return _dict_to_message(last_msg_doc["message"]) + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """ + Delete last message for a key asynchronously. + + Args: + key: Key to delete last message for + """ + # Find the last message + last_msg_doc = await self._async_collection.find_one( + {"session_id": key}, sort=[("index", -1)] + ) + + if last_msg_doc is None: + return None + + # Delete the last message + await self._async_collection.delete_one({"_id": last_msg_doc["_id"]}) + + return _dict_to_message(last_msg_doc["message"]) + + def get_keys(self) -> List[str]: + """ + Get all keys (session IDs). + + Returns: + List of session IDs + """ + # Get distinct session IDs + return self._collection.distinct("session_id") + + async def aget_keys(self) -> List[str]: + """ + Get all keys (session IDs) asynchronously. + + Returns: + List of session IDs + """ + # Get distinct session IDs + return await self._async_collection.distinct("session_id") diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/pyproject.toml b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..5d7ed18b360373b9d5e0b3bb96c0b691fecb8ba6 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/pyproject.toml @@ -0,0 +1,67 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = false +import_path = "llama_index.storage.chat_store.mongo" + +[tool.llamahub.class_authors] +MongoChatStore = "llama-index" + +[tool.mypy] +disallow_untyped_defs = true +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["Vrushab Ghodke <vrushab.ghodke@gmail.com>"] +description = "llama-index storage-chat-store mongo integration" +exclude = ["**/BUILD"] +license = "MIT" +name = "llama-index-storage-chat-store-mongo" +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.9,<4.0" +llama-index-core = "^0.12.0" +pymongo = "^4.11.1" +motor = "^3.7.0" + +[tool.poetry.group.dev.dependencies] +docker = "^7.1.0" +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "^8.3.5" +pytest-asyncio = "^0.25.3" +pytest-cov = "^6.0.0" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" +types-setuptools = "67.1.0.0" + +[tool.poetry.group.dev.dependencies.black] +extras = ["jupyter"] +version = "<=23.9.1,>=23.7.0" + +[tool.poetry.group.dev.dependencies.codespell] +extras = ["toml"] +version = ">=v2.2.6" + +[[tool.poetry.packages]] +include = "llama_index/" diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/BUILD b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dabf212d7e7162849c24a733909ac4f645d75a31 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/__init__.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/test_chat_store_mongo_chat_store.py b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/test_chat_store_mongo_chat_store.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e1adcd6bd1cf24e48da6b6ae6466d439799788 --- /dev/null +++ b/llama-index-integrations/storage/chat_store/llama-index-storage-chat-store-mongo/tests/test_chat_store_mongo_chat_store.py @@ -0,0 +1,475 @@ +import time +import pytest +import docker +from typing import Dict, Generator, Union +from docker.models.containers import Container +from llama_index.core.storage.chat_store.base import BaseChatStore +from llama_index.storage.chat_store.mongo.base import MongoChatStore + +from llama_index.core.llms import ChatMessage + +try: + import pymongo # noqa: F401 + import motor.motor_asyncio # noqa: F401 + + no_packages = False +except ImportError: + no_packages = True + + +def test_class(): + """Test that MongoChatStore inherits from BaseChatStore.""" + names_of_base_classes = [b.__name__ for b in MongoChatStore.__mro__] + assert BaseChatStore.__name__ in names_of_base_classes + + +@pytest.fixture() +def mongo_container() -> Generator[Dict[str, Union[str, Container]], None, None]: + """Fixture to create a MongoDB container for testing.""" + # Define MongoDB settings + mongo_image = "mongo:latest" + mongo_ports = {"27017/tcp": 27017} + container = None + try: + # Initialize Docker client + client = docker.from_env() + + # Run MongoDB container + container = client.containers.run(mongo_image, ports=mongo_ports, detach=True) + + # Wait for MongoDB to start + time.sleep(5) # Give MongoDB time to initialize + + # Return connection information + yield { + "container": container, + "mongodb_uri": "mongodb://localhost:27017/", + } + finally: + # Stop and remove the container + if container: + container.stop() + container.remove() + client.close() + + +@pytest.fixture() +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def mongo_chat_store( + mongo_container: Dict[str, Union[str, Container]], +) -> Generator[MongoChatStore, None, None]: + """Fixture to create a MongoChatStore instance connected to the test container.""" + chat_store = None + try: + chat_store = MongoChatStore( + mongo_uri=mongo_container["mongodb_uri"], + db_name="test_db", + collection_name="test_chats", + ) + yield chat_store + finally: + if chat_store and hasattr(chat_store, "_collection"): + # Clean up by dropping the collection + chat_store._collection.drop() + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_create_chat(mongo_chat_store: MongoChatStore): + """Test creating a chat session.""" + # Create a chat with metadata + key = "test_key" + messages = [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm doing well, thank you!"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Get all keys and verify our chat_id is there + keys = mongo_chat_store.get_keys() + assert key in keys + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_get_messages(mongo_chat_store: MongoChatStore): + """Test retrieving messages from a chat session.""" + # Create a chat with messages + key = "test_get_messages" + messages = [ + ChatMessage(role="user", content="Hello, MongoDB!"), + ChatMessage(role="assistant", content="Hello, user! How can I help you?"), + ChatMessage(role="user", content="I need information about databases."), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Retrieve the messages + retrieved_messages = mongo_chat_store.get_messages(key=key) + + # Verify the messages were retrieved correctly + assert len(retrieved_messages) == 3 + assert retrieved_messages[0].role == "user" + assert retrieved_messages[0].content == "Hello, MongoDB!" + assert retrieved_messages[1].role == "assistant" + assert retrieved_messages[1].content == "Hello, user! How can I help you?" + assert retrieved_messages[2].role == "user" + assert retrieved_messages[2].content == "I need information about databases." + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_add_message(mongo_chat_store: MongoChatStore): + """Test adding a message to an existing chat session.""" + # Create a chat with initial messages + key = "test_add_message" + initial_messages = [ChatMessage(role="user", content="Initial message")] + mongo_chat_store.set_messages(key=key, messages=initial_messages) + + # Add a new message + new_message = ChatMessage(role="assistant", content="Response to initial message") + mongo_chat_store.add_message(key=key, message=new_message) + + # Retrieve all messages + messages = mongo_chat_store.get_messages(key=key) + + # Verify the new message was added + assert len(messages) == 2 + assert messages[0].role == "user" + assert messages[0].content == "Initial message" + assert messages[1].role == "assistant" + assert messages[1].content == "Response to initial message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_delete_messages(mongo_chat_store: MongoChatStore): + """Test deleting all messages from a chat session.""" + # Create a chat with messages + key = "test_delete_messages" + messages = [ + ChatMessage(role="user", content="Message to be deleted"), + ChatMessage(role="assistant", content="This will also be deleted"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Delete all messages + deleted_messages = mongo_chat_store.delete_messages(key=key) + + # Verify the messages were deleted + assert len(deleted_messages) == 2 + assert deleted_messages[0].content == "Message to be deleted" + + # Verify the chat is empty + remaining_messages = mongo_chat_store.get_messages(key=key) + assert len(remaining_messages) == 0 + + # Verify the key is not present in the store + keys = mongo_chat_store.get_keys() + assert key not in keys + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_delete_message(mongo_chat_store: MongoChatStore): + """Test deleting a specific message from a chat session.""" + # Create a chat with multiple messages + key = "test_delete_specific" + messages = [ + ChatMessage(role="user", content="First message"), + ChatMessage(role="assistant", content="Middle message to delete"), + ChatMessage(role="user", content="Last message"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Delete the middle message + deleted_message = mongo_chat_store.delete_message(key=key, idx=1) + + # Verify the correct message was deleted + assert deleted_message.role == "assistant" + assert deleted_message.content == "Middle message to delete" + + # Verify the remaining messages are correct and reindexed + remaining = mongo_chat_store.get_messages(key=key) + assert len(remaining) == 2 + assert remaining[0].content == "First message" + assert remaining[1].content == "Last message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_delete_last_message(mongo_chat_store: MongoChatStore): + """Test deleting the last message from a chat session.""" + # Create a chat with messages + key = "test_delete_last" + messages = [ + ChatMessage(role="user", content="First message"), + ChatMessage(role="assistant", content="Last message to delete"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Delete the last message + deleted = mongo_chat_store.delete_last_message(key=key) + + # Verify the correct message was deleted + assert deleted.role == "assistant" + assert deleted.content == "Last message to delete" + + # Verify only the first message remains + remaining = mongo_chat_store.get_messages(key=key) + assert len(remaining) == 1 + assert remaining[0].content == "First message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_get_messages(mongo_chat_store: MongoChatStore): + """Test retrieving messages asynchronously.""" + # Create a chat with messages + key = "test_async_get" + messages = [ + ChatMessage(role="user", content="Async test message"), + ChatMessage(role="assistant", content="Async response"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Retrieve messages asynchronously + retrieved = await mongo_chat_store.aget_messages(key=key) + + # Verify messages were retrieved correctly + assert len(retrieved) == 2 + assert retrieved[0].content == "Async test message" + assert retrieved[1].content == "Async response" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_add_message(mongo_chat_store: MongoChatStore): + """Test adding a message asynchronously.""" + key = "test_async_add" + initial_message = ChatMessage(role="user", content="Initial async message") + mongo_chat_store.set_messages(key=key, messages=[initial_message]) + + # Add message asynchronously + new_message = ChatMessage(role="assistant", content="Async response") + await mongo_chat_store.async_add_message(key=key, message=new_message) + + # Verify message was added + messages = mongo_chat_store.get_messages(key=key) + assert len(messages) == 2 + assert messages[1].content == "Async response" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_set_messages(mongo_chat_store: MongoChatStore): + """Test setting messages asynchronously.""" + key = "test_async_set" + messages = [ + ChatMessage(role="user", content="First async set message"), + ChatMessage(role="assistant", content="Second async set message"), + ] + + # Set messages asynchronously + await mongo_chat_store.aset_messages(key=key, messages=messages) + + # Verify messages were set correctly + retrieved = await mongo_chat_store.aget_messages(key=key) + assert len(retrieved) == 2 + assert retrieved[0].content == "First async set message" + assert retrieved[1].content == "Second async set message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_delete_messages(mongo_chat_store: MongoChatStore): + """Test deleting all messages asynchronously.""" + key = "test_async_delete_all" + messages = [ + ChatMessage(role="user", content="Async message to delete 1"), + ChatMessage(role="assistant", content="Async message to delete 2"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Delete messages asynchronously + deleted = await mongo_chat_store.adelete_messages(key=key) + + # Verify messages were deleted + assert len(deleted) == 2 + assert deleted[0].content == "Async message to delete 1" + + # Verify no messages remain + remaining = await mongo_chat_store.aget_messages(key=key) + assert len(remaining) == 0 + + # Verify key is not in store + keys = await mongo_chat_store.aget_keys() + assert key not in keys + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_delete_message(mongo_chat_store: MongoChatStore): + """Test deleting a specific message asynchronously.""" + key = "test_async_delete_specific" + messages = [ + ChatMessage(role="user", content="Async first message"), + ChatMessage(role="assistant", content="Async middle message to delete"), + ChatMessage(role="user", content="Async last message"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Delete middle message asynchronously + deleted = await mongo_chat_store.adelete_message(key=key, idx=1) + + # Verify correct message was deleted + assert deleted.role == "assistant" + assert deleted.content == "Async middle message to delete" + + # Verify remaining messages and reindexing + remaining = await mongo_chat_store.aget_messages(key=key) + assert len(remaining) == 2 + assert remaining[0].content == "Async first message" + assert remaining[1].content == "Async last message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_delete_last_message(mongo_chat_store: MongoChatStore): + """Test deleting the last message asynchronously.""" + key = "test_async_delete_last" + messages = [ + ChatMessage(role="user", content="Async first message"), + ChatMessage(role="assistant", content="Async last message to delete"), + ] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Delete last message asynchronously + deleted = await mongo_chat_store.adelete_last_message(key=key) + + # Verify correct message was deleted + assert deleted.role == "assistant" + assert deleted.content == "Async last message to delete" + + # Verify only first message remains + remaining = await mongo_chat_store.aget_messages(key=key) + assert len(remaining) == 1 + assert remaining[0].content == "Async first message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +@pytest.mark.asyncio() +async def test_async_get_keys(mongo_chat_store: MongoChatStore): + """Test getting all keys asynchronously.""" + # Create multiple chats + await mongo_chat_store.aset_messages( + key="async_keys_test1", + messages=[ChatMessage(role="user", content="Test message 1")], + ) + await mongo_chat_store.aset_messages( + key="async_keys_test2", + messages=[ChatMessage(role="user", content="Test message 2")], + ) + + # Get keys asynchronously + keys = await mongo_chat_store.aget_keys() + + # Verify keys were retrieved + assert "async_keys_test1" in keys + assert "async_keys_test2" in keys + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_nonexistent_key(mongo_chat_store: MongoChatStore): + """Test behavior with nonexistent keys.""" + # Try to get messages for nonexistent key + messages = mongo_chat_store.get_messages(key="nonexistent_key") + + # Verify empty list is returned + assert messages == [] + + # Try to delete a message from nonexistent chat + deleted = mongo_chat_store.delete_message(key="nonexistent_key", idx=0) + + # Verify None is returned + assert deleted is None + + # Try to delete last message from nonexistent chat + deleted = mongo_chat_store.delete_last_message(key="nonexistent_key") + + # Verify None is returned + assert deleted is None + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_ttl_configuration(mongo_container): + """Test TTL configuration is applied correctly.""" + # Create chat store with TTL + chat_store = MongoChatStore( + mongo_uri=mongo_container["mongodb_uri"], + db_name="test_ttl_db", + collection_name="test_ttl_chats", + ttl_seconds=3600, # 1 hour TTL + ) + + # Verify TTL index was created + indexes = list(chat_store._collection.list_indexes()) + ttl_index = next((idx for idx in indexes if "created_at" in idx["key"]), None) + + assert ttl_index is not None + assert ttl_index.get("expireAfterSeconds") == 3600 + + # Clean up + chat_store._collection.drop() + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_invalid_message_index(mongo_chat_store: MongoChatStore): + """Test behavior when trying to delete a message with invalid index.""" + key = "test_invalid_index" + messages = [ChatMessage(role="user", content="Only message")] + mongo_chat_store.set_messages(key=key, messages=messages) + + # Try to delete message with out-of-range index + deleted = mongo_chat_store.delete_message(key=key, idx=5) + + # Verify None is returned + assert deleted is None + + # Verify original message still exists + remaining = mongo_chat_store.get_messages(key=key) + assert len(remaining) == 1 + assert remaining[0].content == "Only message" + + +@pytest.mark.skipif(no_packages, reason="pymongo and motor not installed") +def test_multiple_clients(mongo_container): + """Test using multiple chat store instances with the same database.""" + # Create two chat store instances + chat_store1 = MongoChatStore( + mongo_uri=mongo_container["mongodb_uri"], + db_name="test_multi_client_db", + collection_name="test_chats", + ) + + chat_store2 = MongoChatStore( + mongo_uri=mongo_container["mongodb_uri"], + db_name="test_multi_client_db", + collection_name="test_chats", + ) + + # Add message with first client + key = "test_multi_client" + chat_store1.set_messages( + key=key, messages=[ChatMessage(role="user", content="Message from client 1")] + ) + + # Add message with second client + chat_store2.add_message( + key=key, message=ChatMessage(role="assistant", content="Message from client 2") + ) + + # Verify both messages are visible to both clients + messages1 = chat_store1.get_messages(key=key) + messages2 = chat_store2.get_messages(key=key) + + assert len(messages1) == 2 + assert len(messages2) == 2 + assert messages1[0].content == "Message from client 1" + assert messages1[1].content == "Message from client 2" + + # Clean up + chat_store1._collection.drop()