diff --git a/llama_index/ingestion/pipeline.py b/llama_index/ingestion/pipeline.py index 52e7bbf6118636eaabb74dc2610bdba85e0dccdd..6e84cc4b3d053a6a4b1e4835163961cc1a1b2ca7 100644 --- a/llama_index/ingestion/pipeline.py +++ b/llama_index/ingestion/pipeline.py @@ -478,6 +478,76 @@ class IngestionPipeline(BaseModel): return nodes + # ------ async methods ------ + + async def _ahandle_duplicates( + self, + nodes: List[BaseNode], + store_doc_text: bool = True, + ) -> List[BaseNode]: + """Handle docstore duplicates by checking all hashes.""" + assert self.docstore is not None + + existing_hashes = await self.docstore.aget_all_document_hashes() + current_hashes = [] + nodes_to_run = [] + for node in nodes: + if node.hash not in existing_hashes and node.hash not in current_hashes: + await self.docstore.aset_document_hash(node.id_, node.hash) + nodes_to_run.append(node) + current_hashes.append(node.hash) + + await self.docstore.async_add_documents(nodes_to_run, store_text=store_doc_text) + + return nodes_to_run + + async def _ahandle_upserts( + self, + nodes: List[BaseNode], + store_doc_text: bool = True, + ) -> List[BaseNode]: + """Handle docstore upserts by checking hashes and ids.""" + assert self.docstore is not None + + existing_doc_ids_before = set( + (await self.docstore.aget_all_document_hashes()).values() + ) + doc_ids_from_nodes = set() + deduped_nodes_to_run = {} + for node in nodes: + ref_doc_id = node.ref_doc_id if node.ref_doc_id else node.id_ + doc_ids_from_nodes.add(ref_doc_id) + existing_hash = await self.docstore.aget_document_hash(ref_doc_id) + if not existing_hash: + # document doesn't exist, so add it + await self.docstore.aset_document_hash(ref_doc_id, node.hash) + deduped_nodes_to_run[ref_doc_id] = node + elif existing_hash and existing_hash != node.hash: + await self.docstore.adelete_ref_doc(ref_doc_id, raise_error=False) + + if self.vector_store is not None: + await self.vector_store.adelete(ref_doc_id) + + await self.docstore.aset_document_hash(ref_doc_id, node.hash) + + deduped_nodes_to_run[ref_doc_id] = node + else: + continue # document exists and is unchanged, so skip it + + if self.docstore_strategy == DocstoreStrategy.UPSERTS_AND_DELETE: + # Identify missing docs and delete them from docstore and vector store + doc_ids_to_delete = existing_doc_ids_before - doc_ids_from_nodes + for ref_doc_id in doc_ids_to_delete: + await self.docstore.adelete_document(ref_doc_id) + + if self.vector_store is not None: + await self.vector_store.adelete(ref_doc_id) + + nodes_to_run = list(deduped_nodes_to_run.values()) + await self.docstore.async_add_documents(nodes_to_run, store_text=store_doc_text) + + return nodes_to_run + async def arun( self, show_progress: bool = False, @@ -497,11 +567,11 @@ class IngestionPipeline(BaseModel): DocstoreStrategy.UPSERTS, DocstoreStrategy.UPSERTS_AND_DELETE, ): - nodes_to_run = self._handle_upserts( + nodes_to_run = await self._ahandle_upserts( input_nodes, store_doc_text=store_doc_text ) elif self.docstore_strategy == DocstoreStrategy.DUPLICATES_ONLY: - nodes_to_run = self._handle_duplicates( + nodes_to_run = await self._ahandle_duplicates( input_nodes, store_doc_text=store_doc_text ) else: @@ -519,7 +589,7 @@ class IngestionPipeline(BaseModel): "Switching to duplicates_only strategy." ) self.docstore_strategy = DocstoreStrategy.DUPLICATES_ONLY - nodes_to_run = self._handle_duplicates( + nodes_to_run = await self._ahandle_duplicates( input_nodes, store_doc_text=store_doc_text ) diff --git a/llama_index/storage/kvstore/mongodb_kvstore.py b/llama_index/storage/kvstore/mongodb_kvstore.py index 8c5ac2ac5b91fef4100fafb461b798148b137059..81fd758b97794f820b26e27c0aa0c1f4a86307b1 100644 --- a/llama_index/storage/kvstore/mongodb_kvstore.py +++ b/llama_index/storage/kvstore/mongodb_kvstore.py @@ -6,7 +6,9 @@ from llama_index.storage.kvstore.types import ( BaseKVStore, ) -IMPORT_ERROR_MSG = "`pymongo` package not found, please run `pip install pymongo`" +IMPORT_ERROR_MSG = ( + "`pymongo` or `motor` package not found, please run `pip install pymongo motor`" +) class MongoDBKVStore(BaseKVStore): @@ -24,6 +26,7 @@ class MongoDBKVStore(BaseKVStore): def __init__( self, mongo_client: Any, + mongo_aclient: Optional[Any] = None, uri: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None, @@ -31,11 +34,15 @@ class MongoDBKVStore(BaseKVStore): ) -> None: """Init a MongoDBKVStore.""" try: + from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient except ImportError: raise ImportError(IMPORT_ERROR_MSG) self._client = cast(MongoClient, mongo_client) + self._aclient = ( + cast(AsyncIOMotorClient, mongo_aclient) if mongo_aclient else None + ) self._uri = uri self._host = host @@ -43,6 +50,7 @@ class MongoDBKVStore(BaseKVStore): self._db_name = db_name or "db_docstore" self._db = self._client[self._db_name] + self._adb = self._aclient[self._db_name] if self._aclient else None @classmethod def from_uri( @@ -58,13 +66,16 @@ class MongoDBKVStore(BaseKVStore): """ try: + from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient except ImportError: raise ImportError(IMPORT_ERROR_MSG) mongo_client: MongoClient = MongoClient(uri) + mongo_aclient: AsyncIOMotorClient = AsyncIOMotorClient(uri) return cls( mongo_client=mongo_client, + mongo_aclient=mongo_aclient, db_name=db_name, uri=uri, ) @@ -85,18 +96,25 @@ class MongoDBKVStore(BaseKVStore): """ try: + from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient except ImportError: raise ImportError(IMPORT_ERROR_MSG) mongo_client: MongoClient = MongoClient(host, port) + mongo_aclient: AsyncIOMotorClient = AsyncIOMotorClient(host, port) return cls( mongo_client=mongo_client, + mongo_aclient=mongo_aclient, db_name=db_name, host=host, port=port, ) + def _check_async_client(self) -> None: + if self._adb is None: + raise ValueError("MongoDBKVStore was not initialized with an async client") + def put( self, key: str, @@ -111,13 +129,7 @@ class MongoDBKVStore(BaseKVStore): collection (str): collection name """ - val = val.copy() - val["_id"] = key - self._db[collection].replace_one( - {"_id": key}, - val, - upsert=True, - ) + self.put_all([(key, val)], collection=collection) async def aput( self, @@ -133,7 +145,7 @@ class MongoDBKVStore(BaseKVStore): collection (str): collection name """ - raise NotImplementedError + await self.aput_all([(key, val)], collection=collection) def put_all( self, @@ -141,14 +153,47 @@ class MongoDBKVStore(BaseKVStore): collection: str = DEFAULT_COLLECTION, batch_size: int = DEFAULT_BATCH_SIZE, ) -> None: + from pymongo import UpdateOne + # Prepare documents with '_id' set to the key for batch insertion + docs = [{"_id": key, **value} for key, value in kv_pairs] + + # Insert documents in batches + for batch in ( + docs[i : i + batch_size] for i in range(0, len(docs), batch_size) + ): + new_docs = [] + for doc in batch: + new_docs.append( + UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=True) + ) + + self._db[collection].bulk_write(new_docs) + + async def aput_all( + self, + kv_pairs: List[Tuple[str, dict]], + collection: str = DEFAULT_COLLECTION, + batch_size: int = DEFAULT_BATCH_SIZE, + ) -> None: + from pymongo import UpdateOne + + self._check_async_client() + # Prepare documents with '_id' set to the key for batch insertion docs = [{"_id": key, **value} for key, value in kv_pairs] + # Insert documents in batches for batch in ( docs[i : i + batch_size] for i in range(0, len(docs), batch_size) ): - self._db[collection].insert_many(batch) + new_docs = [] + for doc in batch: + new_docs.append( + UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=True) + ) + + await self._adb[collection].bulk_write(new_docs) def get(self, key: str, collection: str = DEFAULT_COLLECTION) -> Optional[dict]: """Get a value from the store. @@ -174,7 +219,13 @@ class MongoDBKVStore(BaseKVStore): collection (str): collection name """ - raise NotImplementedError + self._check_async_client() + + result = await self._adb[collection].find_one({"_id": key}) + if result is not None: + result.pop("_id") + return result + return None def get_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]: """Get all values from the store. @@ -197,7 +248,14 @@ class MongoDBKVStore(BaseKVStore): collection (str): collection name """ - raise NotImplementedError + self._check_async_client() + + results = self._adb[collection].find() + output = {} + for result in await results.to_list(length=None): + key = result.pop("_id") + output[key] = result + return output def delete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool: """Delete a value from the store. @@ -218,4 +276,7 @@ class MongoDBKVStore(BaseKVStore): collection (str): collection name """ - raise NotImplementedError + self._check_async_client() + + result = await self._adb[collection].delete_one({"_id": key}) + return result.deleted_count > 0 diff --git a/poetry.lock b/poetry.lock index cc82e3f28534539891485f7f6a28d2a050f30c11..63aab622b902f97b42ded6cb46afc98186226833 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -3053,6 +3053,30 @@ files = [ {file = "mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e"}, ] +[[package]] +name = "motor" +version = "3.3.2" +description = "Non-blocking MongoDB driver for Tornado or asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "motor-3.3.2-py3-none-any.whl", hash = "sha256:6fe7e6f0c4f430b9e030b9d22549b732f7c2226af3ab71ecc309e4a1b7d19953"}, + {file = "motor-3.3.2.tar.gz", hash = "sha256:d2fc38de15f1c8058f389c1a44a4d4105c0405c48c061cd492a654496f7bc26a"}, +] + +[package.dependencies] +pymongo = ">=4.5,<5" + +[package.extras] +aws = ["pymongo[aws] (>=4.5,<5)"] +encryption = ["pymongo[encryption] (>=4.5,<5)"] +gssapi = ["pymongo[gssapi] (>=4.5,<5)"] +ocsp = ["pymongo[ocsp] (>=4.5,<5)"] +snappy = ["pymongo[snappy] (>=4.5,<5)"] +srv = ["pymongo[srv] (>=4.5,<5)"] +test = ["aiohttp (<3.8.6)", "mockupdb", "motor[encryption]", "pytest (>=7)", "tornado (>=5)"] +zstd = ["pymongo[zstd] (>=4.5,<5)"] + [[package]] name = "mpmath" version = "1.3.0" @@ -3842,7 +3866,7 @@ files = [ [package.dependencies] coloredlogs = "*" datasets = [ - {version = "*"}, + {version = "*", optional = true, markers = "extra != \"onnxruntime\""}, {version = ">=1.2.1", optional = true, markers = "extra == \"onnxruntime\""}, ] evaluate = {version = "*", optional = true, markers = "extra == \"onnxruntime\""} @@ -4337,8 +4361,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -4621,7 +4643,6 @@ files = [ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"}, {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"}, {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"}, - {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"}, @@ -7755,4 +7776,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "daf56947ebd15f8a3114e30570ad505419ead8478f18ee924d53ddf66018dee4" +content-hash = "ef273f99621199d1821deb1bbd567ffb6ba49d8afb4bcb4350c59c6d62ff5427" diff --git a/pyproject.toml b/pyproject.toml index 92a756eb04048b8fca9f3d4e8d3e7a88ba8112af..218d985026fd3033f48c1bab691f5cd3ef4ccbf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,7 @@ codespell = {extras = ["toml"], version = ">=v2.2.6"} google-ai-generativelanguage = {python = ">=3.9,<3.12", version = "^0.4.0"} ipython = "8.10.0" jupyter = "^1.0.0" +motor = "^3.3.2" mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" diff --git a/tests/storage/kvstore/mock_mongodb.py b/tests/storage/kvstore/mock_mongodb.py index 9c06887570acbb767f574ba8fdcf4411b06bae8c..b98b7989f87deb12d0a7bc250997da83a664ef82 100644 --- a/tests/storage/kvstore/mock_mongodb.py +++ b/tests/storage/kvstore/mock_mongodb.py @@ -66,6 +66,12 @@ class MockMongoCollection: insert_result.inserted_ids = inserted_ids return insert_result + def bulk_write(self, operations: List[Any]) -> Any: + for operation in operations: + obj = operation._doc["$set"] + _id = obj.pop("_id") + self.insert_one(obj, _id) + class MockMongoDB: def __init__(self) -> None: