Unverified commit 24c075c2 authored by Logan, committed by GitHub

Add async mongodb, fix batch insert (#10081)

parent 8b614c08
@@ -478,6 +478,76 @@ class IngestionPipeline(BaseModel):
         return nodes
 
+    # ------ async methods ------
+
+    async def _ahandle_duplicates(
+        self,
+        nodes: List[BaseNode],
+        store_doc_text: bool = True,
+    ) -> List[BaseNode]:
+        """Handle docstore duplicates by checking all hashes."""
+        assert self.docstore is not None
+
+        existing_hashes = await self.docstore.aget_all_document_hashes()
+        current_hashes = []
+        nodes_to_run = []
+        for node in nodes:
+            if node.hash not in existing_hashes and node.hash not in current_hashes:
+                await self.docstore.aset_document_hash(node.id_, node.hash)
+                nodes_to_run.append(node)
+                current_hashes.append(node.hash)
+
+        await self.docstore.async_add_documents(nodes_to_run, store_text=store_doc_text)
+
+        return nodes_to_run
+
+    async def _ahandle_upserts(
+        self,
+        nodes: List[BaseNode],
+        store_doc_text: bool = True,
+    ) -> List[BaseNode]:
+        """Handle docstore upserts by checking hashes and ids."""
+        assert self.docstore is not None
+
+        existing_doc_ids_before = set(
+            (await self.docstore.aget_all_document_hashes()).values()
+        )
+        doc_ids_from_nodes = set()
+        deduped_nodes_to_run = {}
+        for node in nodes:
+            ref_doc_id = node.ref_doc_id if node.ref_doc_id else node.id_
+            doc_ids_from_nodes.add(ref_doc_id)
+
+            existing_hash = await self.docstore.aget_document_hash(ref_doc_id)
+            if not existing_hash:
+                # document doesn't exist, so add it
+                await self.docstore.aset_document_hash(ref_doc_id, node.hash)
+                deduped_nodes_to_run[ref_doc_id] = node
+            elif existing_hash and existing_hash != node.hash:
+                await self.docstore.adelete_ref_doc(ref_doc_id, raise_error=False)
+
+                if self.vector_store is not None:
+                    await self.vector_store.adelete(ref_doc_id)
+
+                await self.docstore.aset_document_hash(ref_doc_id, node.hash)
+
+                deduped_nodes_to_run[ref_doc_id] = node
+            else:
+                continue  # document exists and is unchanged, so skip it
+
+        if self.docstore_strategy == DocstoreStrategy.UPSERTS_AND_DELETE:
+            # Identify missing docs and delete them from docstore and vector store
+            doc_ids_to_delete = existing_doc_ids_before - doc_ids_from_nodes
+            for ref_doc_id in doc_ids_to_delete:
+                await self.docstore.adelete_document(ref_doc_id)
+
+                if self.vector_store is not None:
+                    await self.vector_store.adelete(ref_doc_id)
+
+        nodes_to_run = list(deduped_nodes_to_run.values())
+        await self.docstore.async_add_documents(nodes_to_run, store_text=store_doc_text)
+
+        return nodes_to_run
+
     async def arun(
         self,
         show_progress: bool = False,
@@ -497,11 +567,11 @@ class IngestionPipeline(BaseModel):
                 DocstoreStrategy.UPSERTS,
                 DocstoreStrategy.UPSERTS_AND_DELETE,
             ):
-                nodes_to_run = self._handle_upserts(
+                nodes_to_run = await self._ahandle_upserts(
                     input_nodes, store_doc_text=store_doc_text
                 )
             elif self.docstore_strategy == DocstoreStrategy.DUPLICATES_ONLY:
-                nodes_to_run = self._handle_duplicates(
+                nodes_to_run = await self._ahandle_duplicates(
                     input_nodes, store_doc_text=store_doc_text
                )
             else:
@@ -519,7 +589,7 @@ class IngestionPipeline(BaseModel):
                     "Switching to duplicates_only strategy."
                 )
                 self.docstore_strategy = DocstoreStrategy.DUPLICATES_ONLY
-            nodes_to_run = self._handle_duplicates(
+            nodes_to_run = await self._ahandle_duplicates(
                 input_nodes, store_doc_text=store_doc_text
             )
...
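With the hunks above, arun() awaits the new _ahandle_upserts() / _ahandle_duplicates() helpers instead of calling the blocking synchronous ones. A minimal sketch of exercising the async path is shown below; it is illustrative only and not part of this commit, and it assumes a llama-index 0.9-era install with these import paths and a MongoDB instance reachable at mongodb://localhost:27017.

    import asyncio

    from llama_index import Document
    from llama_index.ingestion import DocstoreStrategy, IngestionPipeline
    from llama_index.storage.docstore import MongoDocumentStore
    from llama_index.text_splitter import SentenceSplitter


    async def main() -> None:
        pipeline = IngestionPipeline(
            transformations=[SentenceSplitter()],
            # MongoDocumentStore wraps the MongoDBKVStore updated later in this commit.
            docstore=MongoDocumentStore.from_uri("mongodb://localhost:27017"),
            # With no vector store attached, UPSERTS would fall back to
            # DUPLICATES_ONLY anyway (see the warning branch above).
            docstore_strategy=DocstoreStrategy.DUPLICATES_ONLY,
        )
        # arun() now awaits _ahandle_duplicates() rather than the sync helper.
        nodes = await pipeline.arun(documents=[Document(text="hello world")])
        print(f"processed {len(nodes)} nodes")


    asyncio.run(main())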
@@ -6,7 +6,9 @@ from llama_index.storage.kvstore.types import (
     BaseKVStore,
 )
 
-IMPORT_ERROR_MSG = "`pymongo` package not found, please run `pip install pymongo`"
+IMPORT_ERROR_MSG = (
+    "`pymongo` or `motor` package not found, please run `pip install pymongo motor`"
+)
 
 
 class MongoDBKVStore(BaseKVStore):
@@ -24,6 +26,7 @@ class MongoDBKVStore(BaseKVStore):
     def __init__(
         self,
         mongo_client: Any,
+        mongo_aclient: Optional[Any] = None,
         uri: Optional[str] = None,
         host: Optional[str] = None,
         port: Optional[int] = None,
@@ -31,11 +34,15 @@ class MongoDBKVStore(BaseKVStore):
     ) -> None:
         """Init a MongoDBKVStore."""
        try:
+            from motor.motor_asyncio import AsyncIOMotorClient
             from pymongo import MongoClient
         except ImportError:
             raise ImportError(IMPORT_ERROR_MSG)
 
         self._client = cast(MongoClient, mongo_client)
+        self._aclient = (
+            cast(AsyncIOMotorClient, mongo_aclient) if mongo_aclient else None
+        )
 
         self._uri = uri
         self._host = host
@@ -43,6 +50,7 @@ class MongoDBKVStore(BaseKVStore):
 
         self._db_name = db_name or "db_docstore"
         self._db = self._client[self._db_name]
+        self._adb = self._aclient[self._db_name] if self._aclient else None
 
     @classmethod
     def from_uri(
@@ -58,13 +66,16 @@ class MongoDBKVStore(BaseKVStore):
         """
         try:
+            from motor.motor_asyncio import AsyncIOMotorClient
             from pymongo import MongoClient
         except ImportError:
             raise ImportError(IMPORT_ERROR_MSG)
 
         mongo_client: MongoClient = MongoClient(uri)
+        mongo_aclient: AsyncIOMotorClient = AsyncIOMotorClient(uri)
         return cls(
             mongo_client=mongo_client,
+            mongo_aclient=mongo_aclient,
             db_name=db_name,
             uri=uri,
         )
@@ -85,18 +96,25 @@ class MongoDBKVStore(BaseKVStore):
         """
         try:
+            from motor.motor_asyncio import AsyncIOMotorClient
             from pymongo import MongoClient
         except ImportError:
             raise ImportError(IMPORT_ERROR_MSG)
 
         mongo_client: MongoClient = MongoClient(host, port)
+        mongo_aclient: AsyncIOMotorClient = AsyncIOMotorClient(host, port)
         return cls(
             mongo_client=mongo_client,
+            mongo_aclient=mongo_aclient,
             db_name=db_name,
             host=host,
             port=port,
         )
 
+    def _check_async_client(self) -> None:
+        if self._adb is None:
+            raise ValueError("MongoDBKVStore was not initialized with an async client")
+
     def put(
         self,
         key: str,
@@ -111,13 +129,7 @@ class MongoDBKVStore(BaseKVStore):
             collection (str): collection name
 
         """
-        val = val.copy()
-        val["_id"] = key
-        self._db[collection].replace_one(
-            {"_id": key},
-            val,
-            upsert=True,
-        )
+        self.put_all([(key, val)], collection=collection)
 
     async def aput(
         self,
@@ -133,7 +145,7 @@ class MongoDBKVStore(BaseKVStore):
             collection (str): collection name
 
         """
-        raise NotImplementedError
+        await self.aput_all([(key, val)], collection=collection)
 
     def put_all(
         self,
@@ -141,14 +153,47 @@ class MongoDBKVStore(BaseKVStore):
         collection: str = DEFAULT_COLLECTION,
         batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> None:
+        from pymongo import UpdateOne
+
         # Prepare documents with '_id' set to the key for batch insertion
         docs = [{"_id": key, **value} for key, value in kv_pairs]
 
         # Insert documents in batches
         for batch in (
             docs[i : i + batch_size] for i in range(0, len(docs), batch_size)
         ):
-            self._db[collection].insert_many(batch)
+            new_docs = []
+            for doc in batch:
+                new_docs.append(
+                    UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
+                )
+
+            self._db[collection].bulk_write(new_docs)
+
+    async def aput_all(
+        self,
+        kv_pairs: List[Tuple[str, dict]],
+        collection: str = DEFAULT_COLLECTION,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+    ) -> None:
+        from pymongo import UpdateOne
+
+        self._check_async_client()
+
+        # Prepare documents with '_id' set to the key for batch insertion
+        docs = [{"_id": key, **value} for key, value in kv_pairs]
+
+        # Insert documents in batches
+        for batch in (
+            docs[i : i + batch_size] for i in range(0, len(docs), batch_size)
+        ):
+            new_docs = []
+            for doc in batch:
+                new_docs.append(
+                    UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
+                )
+
+            await self._adb[collection].bulk_write(new_docs)
 
     def get(self, key: str, collection: str = DEFAULT_COLLECTION) -> Optional[dict]:
         """Get a value from the store.
@@ -174,7 +219,13 @@ class MongoDBKVStore(BaseKVStore):
             collection (str): collection name
 
         """
-        raise NotImplementedError
+        self._check_async_client()
+
+        result = await self._adb[collection].find_one({"_id": key})
+        if result is not None:
+            result.pop("_id")
+            return result
+        return None
 
     def get_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]:
         """Get all values from the store.
@@ -197,7 +248,14 @@ class MongoDBKVStore(BaseKVStore):
             collection (str): collection name
 
         """
-        raise NotImplementedError
+        self._check_async_client()
+
+        results = self._adb[collection].find()
+        output = {}
+        for result in await results.to_list(length=None):
+            key = result.pop("_id")
+            output[key] = result
+        return output
 
     def delete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool:
         """Delete a value from the store.
@@ -218,4 +276,7 @@ class MongoDBKVStore(BaseKVStore):
             collection (str): collection name
 
         """
-        raise NotImplementedError
+        self._check_async_client()
+
+        result = await self._adb[collection].delete_one({"_id": key})
+        return result.deleted_count > 0
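With these hunks, the previously unimplemented async methods (aput, aput_all, aget, aget_all, adelete) are backed by motor's AsyncIOMotorClient, and both the sync and async batch paths upsert via bulk_write with UpdateOne(..., upsert=True) instead of insert_many, so re-ingesting an existing key no longer fails with a duplicate-key error. A rough usage sketch follows; it is illustrative only and assumes pymongo and motor are installed and that MongoDB is reachable at mongodb://localhost:27017.

    import asyncio

    from llama_index.storage.kvstore import MongoDBKVStore


    async def main() -> None:
        # from_uri() now wires up both a MongoClient and an AsyncIOMotorClient.
        kvstore = MongoDBKVStore.from_uri("mongodb://localhost:27017")

        # Batched async upsert; repeating keys is safe with the bulk_write path.
        await kvstore.aput_all([("doc-1", {"text": "hello"}), ("doc-2", {"text": "world"})])
        await kvstore.aput_all([("doc-1", {"text": "hello again"})])

        print(await kvstore.aget("doc-1"))     # {'text': 'hello again'}
        print(await kvstore.adelete("doc-2"))  # True


    asyncio.run(main())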
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -3053,6 +3053,30 @@ files = [
     {file = "mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e"},
 ]
 
+[[package]]
+name = "motor"
+version = "3.3.2"
+description = "Non-blocking MongoDB driver for Tornado or asyncio"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "motor-3.3.2-py3-none-any.whl", hash = "sha256:6fe7e6f0c4f430b9e030b9d22549b732f7c2226af3ab71ecc309e4a1b7d19953"},
+    {file = "motor-3.3.2.tar.gz", hash = "sha256:d2fc38de15f1c8058f389c1a44a4d4105c0405c48c061cd492a654496f7bc26a"},
+]
+
+[package.dependencies]
+pymongo = ">=4.5,<5"
+
+[package.extras]
+aws = ["pymongo[aws] (>=4.5,<5)"]
+encryption = ["pymongo[encryption] (>=4.5,<5)"]
+gssapi = ["pymongo[gssapi] (>=4.5,<5)"]
+ocsp = ["pymongo[ocsp] (>=4.5,<5)"]
+snappy = ["pymongo[snappy] (>=4.5,<5)"]
+srv = ["pymongo[srv] (>=4.5,<5)"]
+test = ["aiohttp (<3.8.6)", "mockupdb", "motor[encryption]", "pytest (>=7)", "tornado (>=5)"]
+zstd = ["pymongo[zstd] (>=4.5,<5)"]
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -3842,7 +3866,7 @@ files = [
 [package.dependencies]
 coloredlogs = "*"
 datasets = [
-    {version = "*"},
+    {version = "*", optional = true, markers = "extra != \"onnxruntime\""},
     {version = ">=1.2.1", optional = true, markers = "extra == \"onnxruntime\""},
 ]
 evaluate = {version = "*", optional = true, markers = "extra == \"onnxruntime\""}
@@ -4337,8 +4361,6 @@ files = [
     {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
     {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
     {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
-    {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
-    {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
     {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
     {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
     {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@@ -4621,7 +4643,6 @@ files = [
     {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
     {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
     {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
-    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@@ -7755,4 +7776,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "daf56947ebd15f8a3114e30570ad505419ead8478f18ee924d53ddf66018dee4"
+content-hash = "ef273f99621199d1821deb1bbd567ffb6ba49d8afb4bcb4350c59c6d62ff5427"
@@ -114,6 +114,7 @@ codespell = {extras = ["toml"], version = ">=v2.2.6"}
 google-ai-generativelanguage = {python = ">=3.9,<3.12", version = "^0.4.0"}
 ipython = "8.10.0"
 jupyter = "^1.0.0"
+motor = "^3.3.2"
 mypy = "0.991"
 pre-commit = "3.2.0"
 pylint = "2.15.10"
...
@@ -66,6 +66,12 @@ class MockMongoCollection:
         insert_result.inserted_ids = inserted_ids
         return insert_result
 
+    def bulk_write(self, operations: List[Any]) -> Any:
+        for operation in operations:
+            obj = operation._doc["$set"]
+            _id = obj.pop("_id")
+            self.insert_one(obj, _id)
+
 
 class MockMongoDB:
     def __init__(self) -> None:
...