diff --git a/llama_index/storage/docstore/postgres_docstore.py b/llama_index/storage/docstore/postgres_docstore.py
new file mode 100644
index 0000000000000000000000000000000000000000..24e55f55058c58be83922f3b35861f4c7ec8de3d
--- /dev/null
+++ b/llama_index/storage/docstore/postgres_docstore.py
@@ -0,0 +1,78 @@
+from typing import Optional
+
+from llama_index.storage.docstore.keyval_docstore import KVDocumentStore
+from llama_index.storage.docstore.types import DEFAULT_BATCH_SIZE
+from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore
+
+
+class PostgresDocumentStore(KVDocumentStore):
+    """Postgres Document (Node) store.
+
+    A Postgres store for Document and Node objects.
+
+    Args:
+        postgres_kvstore (PostgresKVStore): Postgres key-value store
+        namespace (str): namespace for the docstore
+        batch_size (int): batch size for bulk operations
+
+    """
+
+    def __init__(
+        self,
+        postgres_kvstore: PostgresKVStore,
+        namespace: Optional[str] = None,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+    ) -> None:
+        """Init a PostgresDocumentStore."""
+        super().__init__(postgres_kvstore, namespace=namespace, batch_size=batch_size)
+
+    @classmethod
+    def from_uri(
+        cls,
+        uri: str,
+        namespace: Optional[str] = None,
+        table_name: str = "docstore",
+        schema_name: str = "public",
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> "PostgresDocumentStore":
+        """Load a PostgresDocumentStore from a Postgres URI."""
+        postgres_kvstore = PostgresKVStore.from_uri(
+            uri=uri,
+            table_name=table_name,
+            schema_name=schema_name,
+            perform_setup=perform_setup,
+            debug=debug,
+            use_jsonb=use_jsonb,
+        )
+        return cls(postgres_kvstore, namespace)
+
+    @classmethod
+    def from_params(
+        cls,
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        database: Optional[str] = None,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        namespace: Optional[str] = None,
+        table_name: str = "docstore",
+        schema_name: str = "public",
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> "PostgresDocumentStore":
+        """Load a PostgresDocumentStore from a Postgres host and port."""
+        postgres_kvstore = PostgresKVStore.from_params(
+            host=host,
+            port=port,
+            database=database,
+            user=user,
+            password=password,
+            table_name=table_name,
+            schema_name=schema_name,
+            perform_setup=perform_setup,
+            debug=debug,
+            use_jsonb=use_jsonb,
+        )
+        return cls(postgres_kvstore, namespace)
diff --git a/llama_index/storage/index_store/postgres_index_store.py b/llama_index/storage/index_store/postgres_index_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0a4cebc433140becd523e023b1bc420fc98747e
--- /dev/null
+++ b/llama_index/storage/index_store/postgres_index_store.py
@@ -0,0 +1,74 @@
+from typing import Optional
+
+from llama_index.storage.index_store.keyval_index_store import KVIndexStore
+from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore
+
+
+class PostgresIndexStore(KVIndexStore):
+    """Postgres Index store.
+
+    Args:
+        postgres_kvstore (PostgresKVStore): Postgres key-value store
+        namespace (str): namespace for the index store
+
+    """
+
+    def __init__(
+        self,
+        postgres_kvstore: PostgresKVStore,
+        namespace: Optional[str] = None,
+    ) -> None:
+        """Init a PostgresIndexStore."""
+        super().__init__(postgres_kvstore, namespace=namespace)
+
+    @classmethod
+    def from_uri(
+        cls,
+        uri: str,
+        namespace: Optional[str] = None,
+        table_name: str = "indexstore",
+        schema_name: str = "public",
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> "PostgresIndexStore":
+        """Load a PostgresIndexStore from a Postgres URI."""
+        postgres_kvstore = PostgresKVStore.from_uri(
+            uri=uri,
+            table_name=table_name,
+            schema_name=schema_name,
+            perform_setup=perform_setup,
+            debug=debug,
+            use_jsonb=use_jsonb,
+        )
+        return cls(postgres_kvstore, namespace)
+
+    @classmethod
+    def from_params(
+        cls,
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        database: Optional[str] = None,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        namespace: Optional[str] = None,
+        table_name: str = "indexstore",
+        schema_name: str = "public",
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> "PostgresIndexStore":
+        """Load a PostgresIndexStore from a Postgres host and port."""
+        postgres_kvstore = PostgresKVStore.from_params(
+            host=host,
+            port=port,
+            database=database,
+            user=user,
+            password=password,
+            table_name=table_name,
+            schema_name=schema_name,
+            perform_setup=perform_setup,
+            debug=debug,
+            use_jsonb=use_jsonb,
+        )
+        return cls(postgres_kvstore, namespace)
diff --git a/llama_index/storage/kvstore/postgres_kvstore.py b/llama_index/storage/kvstore/postgres_kvstore.py
new file mode 100644
index 0000000000000000000000000000000000000000..e38108c72ad557f6f65bc0b8e5ea1debc9e77be7
--- /dev/null
+++ b/llama_index/storage/kvstore/postgres_kvstore.py
@@ -0,0 +1,452 @@
+import json
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+from llama_index.storage.kvstore.types import (
+    DEFAULT_BATCH_SIZE,
+    DEFAULT_COLLECTION,
+    BaseKVStore,
+)
+
+IMPORT_ERROR_MSG = "`asyncpg` package not found, please run `pip install asyncpg`"
+
+
+def get_data_model(
+    base: Type,
+    index_name: str,
+    schema_name: str,
+    use_jsonb: bool = False,
+) -> Any:
+    """Create a dynamic SQLAlchemy model backed by a new table."""
+    from sqlalchemy import Column, Index, Integer, UniqueConstraint
+    from sqlalchemy.dialects.postgresql import JSON, JSONB, VARCHAR
+
+    tablename = "data_%s" % index_name  # dynamic table name
+    class_name = "Data%s" % index_name  # dynamic class name
+
+    metadata_dtype = JSONB if use_jsonb else JSON
+
+    class AbstractData(base):  # type: ignore
+        __abstract__ = True  # this line is necessary
+        id = Column(Integer, primary_key=True, autoincrement=True)
+        key = Column(VARCHAR, nullable=False)
+        namespace = Column(VARCHAR, nullable=False)
+        value = Column(metadata_dtype)
+
+    return type(
+        class_name,
+        (AbstractData,),
+        {
+            "__tablename__": tablename,
+            "__table_args__": (
+                UniqueConstraint(
+                    "key", "namespace", name=f"{tablename}:unique_key_namespace"
+                ),
+                Index(f"{tablename}:idx_key_namespace", "key", "namespace"),
+                {"schema": schema_name},
+            ),
+        },
+    )
+
+
+class PostgresKVStore(BaseKVStore):
+    """Postgres Key-Value store.
+
+    Args:
+        connection_string (str): psycopg2 connection string
+        async_connection_string (str): asyncpg connection string
+        table_name (str): table name
+        schema_name (str): schema name
+        perform_setup (bool): create the schema and table if they do not exist
+        debug (bool): echo SQL statements
+        use_jsonb (bool): store values as JSONB instead of JSON
+    """
+
+    connection_string: str
+    async_connection_string: str
+    table_name: str
+    schema_name: str
+    perform_setup: bool
+    debug: bool
+    use_jsonb: bool
+
+    def __init__(
+        self,
+        connection_string: str,
+        async_connection_string: str,
+        table_name: str,
+        schema_name: str = "public",
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> None:
+        try:
+            import asyncpg  # noqa
+            import psycopg2  # noqa
+            import sqlalchemy  # noqa
+            import sqlalchemy.ext.asyncio  # noqa
+        except ImportError:
+            raise ImportError(
+                "`sqlalchemy[asyncio]`, `psycopg2-binary` and `asyncpg` "
+                "packages must be installed"
+            )
+
+        table_name = table_name.lower()
+        schema_name = schema_name.lower()
+        self.connection_string = connection_string
+        self.async_connection_string = async_connection_string
+        self.table_name = table_name
+        self.schema_name = schema_name
+        self.perform_setup = perform_setup
+        self.debug = debug
+        self.use_jsonb = use_jsonb
+        self._is_initialized = False
+
+        from sqlalchemy.orm import declarative_base
+
+        # sqlalchemy model
+        self._base = declarative_base()
+        self._table_class = get_data_model(
+            self._base,
+            table_name,
+            schema_name,
+            use_jsonb=use_jsonb,
+        )
+
+    @classmethod
+    def from_params(
+        cls,
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        database: Optional[str] = None,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        table_name: str = "kvstore",
+        schema_name: str = "public",
+        connection_string: Optional[str] = None,
+        async_connection_string: Optional[str] = None,
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> "PostgresKVStore":
+        """Construct a PostgresKVStore from database parameters."""
+        conn_str = (
+            connection_string
+            or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
+        )
+        async_conn_str = async_connection_string or (
+            f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
+        )
+        return cls(
+            connection_string=conn_str,
+            async_connection_string=async_conn_str,
+            table_name=table_name,
+            schema_name=schema_name,
+            perform_setup=perform_setup,
+            debug=debug,
+            use_jsonb=use_jsonb,
+        )
+
+    @classmethod
+    def from_uri(
+        cls,
+        uri: str,
+        table_name: str = "kvstore",
+        schema_name: str = "public",
+        perform_setup: bool = True,
+        debug: bool = False,
+        use_jsonb: bool = False,
+    ) -> "PostgresKVStore":
+        """Construct a PostgresKVStore from a Postgres URI."""
+        from sqlalchemy.engine.url import make_url
+
+        # make_url() parses a full URI into its components; URL.create()
+        # expects individual components and would misread a whole URI.
+        url = make_url(uri)
+        return cls.from_params(
+            host=url.host,
+            port=url.port,
+            database=url.database,
+            user=url.username,
+            password=url.password,
+            table_name=table_name,
+            schema_name=schema_name,
+            perform_setup=perform_setup,
+            debug=debug,
+            use_jsonb=use_jsonb,
+        )
+
+    def _connect(self) -> None:
+        from sqlalchemy import create_engine
+        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+        from sqlalchemy.orm import sessionmaker
+
+        self._engine = create_engine(self.connection_string, echo=self.debug)
+        self._session = sessionmaker(self._engine)
+
+        self._async_engine = create_async_engine(self.async_connection_string)
+        self._async_session = sessionmaker(self._async_engine, class_=AsyncSession)
+
+    def _create_schema_if_not_exists(self) -> None:
+        with self._session() as session, session.begin():
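+            # information_schema.schemata is the standard catalog view for
+            # listing schemas, so this existence check is portable across
+            # Postgres versions.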
+            from sqlalchemy import text
+
+            # Check whether the specified schema already exists
+            check_schema_statement = text(
+                "SELECT schema_name FROM information_schema.schemata "
+                "WHERE schema_name = :schema_name"
+            )
+            result = session.execute(
+                check_schema_statement, {"schema_name": self.schema_name}
+            ).fetchone()
+
+            # If the schema does not exist, then create it
+            # (identifiers cannot be bound as parameters, and schema_name is
+            # lowercased configuration rather than user input)
+            if not result:
+                create_schema_statement = text(
+                    f"CREATE SCHEMA IF NOT EXISTS {self.schema_name}"
+                )
+                session.execute(create_schema_statement)
+
+    def _create_tables_if_not_exists(self) -> None:
+        with self._session() as session, session.begin():
+            self._base.metadata.create_all(session.connection())
+
+    def _initialize(self) -> None:
+        if not self._is_initialized:
+            self._connect()
+            if self.perform_setup:
+                self._create_schema_if_not_exists()
+                self._create_tables_if_not_exists()
+            self._is_initialized = True
+
+    def put(
+        self,
+        key: str,
+        val: dict,
+        collection: str = DEFAULT_COLLECTION,
+    ) -> None:
+        """Put a key-value pair into the store.
+
+        Args:
+            key (str): key
+            val (dict): value
+            collection (str): collection name
+
+        """
+        self.put_all([(key, val)], collection=collection)
+
+    async def aput(
+        self,
+        key: str,
+        val: dict,
+        collection: str = DEFAULT_COLLECTION,
+    ) -> None:
+        """Put a key-value pair into the store.
+
+        Args:
+            key (str): key
+            val (dict): value
+            collection (str): collection name
+
+        """
+        await self.aput_all([(key, val)], collection=collection)
+
+    def put_all(
+        self,
+        kv_pairs: List[Tuple[str, dict]],
+        collection: str = DEFAULT_COLLECTION,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+    ) -> None:
+        from sqlalchemy import text
+
+        self._initialize()
+        with self._session() as session:
+            for i in range(0, len(kv_pairs), batch_size):
+                batch = kv_pairs[i : i + batch_size]
+
+                # Prepare the VALUES part of the SQL statement
+                values_clause = ", ".join(
+                    f"(:key_{i}, :namespace_{i}, :value_{i})"
+                    for i, _ in enumerate(batch)
+                )
+
+                # Prepare the raw SQL for bulk upsert
+                # Note: This SQL is PostgreSQL-specific. Adjust for other databases.
+                stmt = text(
+                    f"""
+                    INSERT INTO {self.schema_name}.{self._table_class.__tablename__} (key, namespace, value)
+                    VALUES {values_clause}
+                    ON CONFLICT (key, namespace)
+                    DO UPDATE SET
+                    value = EXCLUDED.value;
+                    """
+                )
+
+                # Flatten the list of tuples for execute parameters
+                params = {}
+                for i, (key, value) in enumerate(batch):
+                    params[f"key_{i}"] = key
+                    params[f"namespace_{i}"] = collection
+                    params[f"value_{i}"] = json.dumps(value)
+
+                # Execute the bulk upsert
+                session.execute(stmt, params)
+                session.commit()
+
+    async def aput_all(
+        self,
+        kv_pairs: List[Tuple[str, dict]],
+        collection: str = DEFAULT_COLLECTION,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+    ) -> None:
+        from sqlalchemy import text
+
+        self._initialize()
+        async with self._async_session() as session:
+            for i in range(0, len(kv_pairs), batch_size):
+                batch = kv_pairs[i : i + batch_size]
+
+                # Prepare the VALUES part of the SQL statement
+                values_clause = ", ".join(
+                    f"(:key_{i}, :namespace_{i}, :value_{i})"
+                    for i, _ in enumerate(batch)
+                )
+
+                # Prepare the raw SQL for bulk upsert
+                # Note: This SQL is PostgreSQL-specific. Adjust for other databases.
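+                # The ON CONFLICT target matches the UniqueConstraint on
+                # (key, namespace) defined in get_data_model, so conflicting
+                # rows are updated in place (an upsert) instead of raising.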
+ stmt = text( + f""" + INSERT INTO {self.schema_name}.{self._table_class.__tablename__} (key, namespace, value) + VALUES {values_clause} + ON CONFLICT (key, namespace) + DO UPDATE SET + value = EXCLUDED.value; + """ + ) + + # Flatten the list of tuples for execute parameters + params = {} + for i, (key, value) in enumerate(batch): + params[f"key_{i}"] = key + params[f"namespace_{i}"] = collection + params[f"value_{i}"] = json.dumps(value) + + # Execute the bulk upsert + await session.execute(stmt, params) + await session.commit() + + def get(self, key: str, collection: str = DEFAULT_COLLECTION) -> Optional[dict]: + """Get a value from the store. + + Args: + key (str): key + collection (str): collection name + + """ + from sqlalchemy import select + + self._initialize() + with self._session() as session: + result = session.execute( + select(self._table_class) + .filter_by(key=key) + .filter_by(namespace=collection) + ) + result = result.scalars().first() + if result: + return result.value + return None + + async def aget( + self, key: str, collection: str = DEFAULT_COLLECTION + ) -> Optional[dict]: + """Get a value from the store. + + Args: + key (str): key + collection (str): collection name + + """ + from sqlalchemy import select + + self._initialize() + async with self._async_session() as session: + result = await session.execute( + select(self._table_class) + .filter_by(key=key) + .filter_by(namespace=collection) + ) + result = result.scalars().first() + if result: + return result.value + return None + + def get_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]: + """Get all values from the store. + + Args: + collection (str): collection name + + """ + from sqlalchemy import select + + self._initialize() + with self._session() as session: + results = session.execute( + select(self._table_class).filter_by(namespace=collection) + ) + results = results.scalars().all() + return {result.key: result.value for result in results} if results else {} + + async def aget_all(self, collection: str = DEFAULT_COLLECTION) -> Dict[str, dict]: + """Get all values from the store. + + Args: + collection (str): collection name + + """ + from sqlalchemy import select + + self._initialize() + async with self._async_session() as session: + results = await session.execute( + select(self._table_class).filter_by(namespace=collection) + ) + results = results.scalars().all() + return {result.key: result.value for result in results} if results else {} + + def delete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool: + """Delete a value from the store. + + Args: + key (str): key + collection (str): collection name + + """ + from sqlalchemy import delete + + self._initialize() + with self._session() as session: + result = session.execute( + delete(self._table_class) + .filter_by(namespace=collection) + .filter_by(key=key) + ) + session.commit() + return result.rowcount > 0 + + async def adelete(self, key: str, collection: str = DEFAULT_COLLECTION) -> bool: + """Delete a value from the store. 
+ + Args: + key (str): key + collection (str): collection name + + """ + from sqlalchemy import delete + + self._initialize() + async with self._async_session() as session: + async with session.begin(): + result = await session.execute( + delete(self._table_class) + .filter_by(namespace=collection) + .filter_by(key=key) + ) + return result.rowcount > 0 diff --git a/poetry.lock b/poetry.lock index 3d0e8e727c3ce3e252b20280a4e00e90e5db55cb..be3941e6a5e4fcb14c6573974292cc3420a6f016 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1230,6 +1230,27 @@ idna = ["idna (>=2.1)"] trio = ["trio (>=0.14)"] wmi = ["wmi (>=1.5.1)"] +[[package]] +name = "docker" +version = "7.0.0" +description = "A Python library for the Docker Engine API." +optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.0.0-py3-none-any.whl", hash = "sha256:12ba681f2777a0ad28ffbcc846a69c31b4dfd9752b47eb425a274ee269c5e14b"}, + {file = "docker-7.0.0.tar.gz", hash = "sha256:323736fb92cd9418fc5e7133bc953e11a9da04f4483f828b527db553f1e7e5a3"}, +] + +[package.dependencies] +packaging = ">=14.0" +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + [[package]] name = "docutils" version = "0.16" @@ -7790,4 +7811,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "6d11f5b95418266365cd690eaa14e59456e23c056c64157698f67fb333bba0b3" +content-hash = "1ccf014ec22186ffcbec1325d4876089a318f0bec886beadd05b39a145e6a86a" diff --git a/pyproject.toml b/pyproject.toml index 576571900a17096e635aa0fcc29826ad305630e1..a61008143c611516ef8cfe73ea21d44c43ae3787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} boto3 = "1.33.6" # needed for tests botocore = ">=1.33.13" codespell = {extras = ["toml"], version = ">=v2.2.6"} +docker = "^7.0.0" google-ai-generativelanguage = {python = ">=3.9,<3.12", version = "^0.4.0"} ipython = "8.10.0" jupyter = "^1.0.0" diff --git a/tests/storage/conftest.py b/tests/storage/conftest.py index 39e7c10777bda26c7c18d46588159807df8507b2..d2efafd564f22be64d3571db7e15d475ae8fb386 100644 --- a/tests/storage/conftest.py +++ b/tests/storage/conftest.py @@ -1,6 +1,12 @@ +import time +from typing import Any, Dict, Generator + +import docker import pytest +from docker.models.containers import Container from llama_index.storage.kvstore.firestore_kvstore import FirestoreKVStore from llama_index.storage.kvstore.mongodb_kvstore import MongoDBKVStore +from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore from llama_index.storage.kvstore.redis_kvstore import RedisKVStore from llama_index.storage.kvstore.simple_kvstore import SimpleKVStore @@ -36,3 +42,67 @@ def redis_kvstore() -> "RedisKVStore": except ImportError: return RedisKVStore(redis_client=None, redis_url="redis://127.0.0.1:6379") return RedisKVStore(redis_client=client) + + +@pytest.fixture(scope="module") +def postgres_container() -> Generator[Dict[str, Any], None, None]: + # Define PostgreSQL settings + postgres_image = "postgres:latest" + postgres_env = { + "POSTGRES_DB": "testdb", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpassword", + } + postgres_ports = {"5432/tcp": 5432} + container = None + try: + # Initialize Docker client + client = docker.from_env() + + # Run PostgreSQL 
container
+        container = client.containers.run(
+            postgres_image, environment=postgres_env, ports=postgres_ports, detach=True
+        )
+
+        # Retrieve the container's port
+        container.reload()
+        postgres_port = container.attrs["NetworkSettings"]["Ports"]["5432/tcp"][0][
+            "HostPort"
+        ]
+
+        # Wait for PostgreSQL to start
+        time.sleep(10)  # Adjust the sleep time if necessary
+
+        # Return connection information, using the host port actually bound
+        yield {
+            "container": container,
+            "connection_string": f"postgresql://testuser:testpassword@0.0.0.0:{postgres_port}/testdb",
+            "async_connection_string": f"postgresql+asyncpg://testuser:testpassword@0.0.0.0:{postgres_port}/testdb",
+        }
+    finally:
+        # Stop and remove the container
+        if container:
+            container.stop()
+            container.remove()
+        client.close()
+
+
+@pytest.fixture()
+def postgres_kvstore(
+    postgres_container: Dict[str, Any],
+) -> Generator[PostgresKVStore, None, None]:
+    kvstore = None
+    try:
+        kvstore = PostgresKVStore(
+            connection_string=postgres_container["connection_string"],
+            async_connection_string=postgres_container["async_connection_string"],
+            table_name="test_kvstore",
+            schema_name="test_schema",
+            use_jsonb=True,
+        )
+        yield kvstore
+    finally:
+        if kvstore:
+            keys = kvstore.get_all().keys()
+            for key in keys:
+                kvstore.delete(key)
diff --git a/tests/storage/docstore/test_postgres_docstore.py b/tests/storage/docstore/test_postgres_docstore.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f8c5d518ff6028388dcec80937d958eb176a41
--- /dev/null
+++ b/tests/storage/docstore/test_postgres_docstore.py
@@ -0,0 +1,80 @@
+from typing import List
+
+import pytest
+from llama_index.schema import BaseNode, Document
+from llama_index.storage.docstore.postgres_docstore import PostgresDocumentStore
+from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore
+
+try:
+    import asyncpg  # noqa
+    import psycopg2  # noqa
+    import sqlalchemy  # noqa
+
+    no_packages = False
+except ImportError:
+    no_packages = True
+
+
+@pytest.fixture()
+def documents() -> List[Document]:
+    return [
+        Document(text="doc_1"),
+        Document(text="doc_2"),
+    ]
+
+
+@pytest.fixture()
+def postgres_docstore(postgres_kvstore: PostgresKVStore) -> PostgresDocumentStore:
+    return PostgresDocumentStore(postgres_kvstore=postgres_kvstore)
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+def test_postgres_docstore(
+    postgres_docstore: PostgresDocumentStore, documents: List[Document]
+) -> None:
+    ds = postgres_docstore
+    assert len(ds.docs) == 0
+
+    # test adding documents
+    ds.add_documents(documents)
+    assert len(ds.docs) == 2
+    assert all(isinstance(doc, BaseNode) for doc in ds.docs.values())
+
+    # test updating documents
+    ds.add_documents(documents)
+    assert len(ds.docs) == 2
+
+    # test getting documents
+    doc0 = ds.get_document(documents[0].get_doc_id())
+    assert doc0 is not None
+    assert documents[0].get_content() == doc0.get_content()
+
+    # test deleting documents
+    ds.delete_document(documents[0].get_doc_id())
+    assert len(ds.docs) == 1
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+def test_postgres_docstore_hash(
+    postgres_docstore: PostgresDocumentStore, documents: List[Document]
+) -> None:
+    ds = postgres_docstore
+
+    # Test setting hash
+    ds.set_document_hash("test_doc_id", "test_doc_hash")
+    doc_hash = ds.get_document_hash("test_doc_id")
+    assert doc_hash == "test_doc_hash"
+
+    # Test updating hash
+    ds.set_document_hash("test_doc_id", "test_doc_hash_new")
+    doc_hash = ds.get_document_hash("test_doc_id")
+    assert doc_hash == "test_doc_hash_new"
+
+    # Test getting non-existent
+    doc_hash = ds.get_document_hash("test_not_exist")
+    assert doc_hash is None
diff --git a/tests/storage/index_store/test_postgres_index_store.py b/tests/storage/index_store/test_postgres_index_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..743a4e7bce3191283d2b69fabd6686e2cd0ebb38
--- /dev/null
+++ b/tests/storage/index_store/test_postgres_index_store.py
@@ -0,0 +1,29 @@
+import pytest
+from llama_index.data_structs.data_structs import IndexGraph
+from llama_index.storage.index_store.postgres_index_store import PostgresIndexStore
+from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore
+
+try:
+    import asyncpg  # noqa
+    import psycopg2  # noqa
+    import sqlalchemy  # noqa
+
+    no_packages = False
+except ImportError:
+    no_packages = True
+
+
+@pytest.fixture()
+def postgres_indexstore(postgres_kvstore: PostgresKVStore) -> PostgresIndexStore:
+    return PostgresIndexStore(postgres_kvstore=postgres_kvstore)
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+def test_postgres_index_store(postgres_indexstore: PostgresIndexStore) -> None:
+    index_struct = IndexGraph()
+    index_store = postgres_indexstore
+
+    index_store.add_index_struct(index_struct)
+    assert index_store.get_index_struct(struct_id=index_struct.index_id) == index_struct
diff --git a/tests/storage/kvstore/test_postgres_kvstore.py b/tests/storage/kvstore/test_postgres_kvstore.py
new file mode 100644
index 0000000000000000000000000000000000000000..7862bd7ea321b713a59d8ed28f0940e70d389737
--- /dev/null
+++ b/tests/storage/kvstore/test_postgres_kvstore.py
@@ -0,0 +1,141 @@
+import pytest
+from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore
+
+try:
+    import asyncpg  # noqa
+    import psycopg2  # noqa
+    import sqlalchemy  # noqa
+
+    no_packages = False
+except ImportError:
+    no_packages = True
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+def test_kvstore_basic(postgres_kvstore: PostgresKVStore) -> None:
+    test_key = "test_key_basic"
+    test_blob = {"test_obj_key": "test_obj_val"}
+    postgres_kvstore.put(test_key, test_blob)
+    blob = postgres_kvstore.get(test_key)
+    assert blob == test_blob
+
+    blob = postgres_kvstore.get(test_key, collection="non_existent")
+    assert blob is None
+
+    deleted = postgres_kvstore.delete(test_key)
+    assert deleted
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+@pytest.mark.asyncio()
+async def test_kvstore_async_basic(postgres_kvstore: PostgresKVStore) -> None:
+    test_key = "test_key_basic"
+    test_blob = {"test_obj_key": "test_obj_val"}
+    await postgres_kvstore.aput(test_key, test_blob)
+    blob = await postgres_kvstore.aget(test_key)
+    assert blob == test_blob
+
+    blob = await postgres_kvstore.aget(test_key, collection="non_existent")
+    assert blob is None
+
+    deleted = await postgres_kvstore.adelete(test_key)
+    assert deleted
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+def test_kvstore_delete(postgres_kvstore: PostgresKVStore) -> None:
+    test_key = "test_key_delete"
+    test_blob = {"test_obj_key": "test_obj_val"}
+    postgres_kvstore.put(test_key, test_blob)
+    blob = postgres_kvstore.get(test_key)
+    assert blob == test_blob
+
+    postgres_kvstore.delete(test_key)
+    blob = postgres_kvstore.get(test_key)
+    assert blob is None
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+@pytest.mark.asyncio()
+async def test_kvstore_adelete(postgres_kvstore: PostgresKVStore) -> None:
+    test_key = "test_key_delete"
+    test_blob = {"test_obj_key": "test_obj_val"}
+    await postgres_kvstore.aput(test_key, test_blob)
+    blob = await postgres_kvstore.aget(test_key)
+    assert blob == test_blob
+
+    await postgres_kvstore.adelete(test_key)
+    blob = await postgres_kvstore.aget(test_key)
+    assert blob is None
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+def test_kvstore_getall(postgres_kvstore: PostgresKVStore) -> None:
+    test_key_1 = "test_key_1"
+    test_blob_1 = {"test_obj_key": "test_obj_val"}
+    postgres_kvstore.put(test_key_1, test_blob_1)
+    blob = postgres_kvstore.get(test_key_1)
+    assert blob == test_blob_1
+    test_key_2 = "test_key_2"
+    test_blob_2 = {"test_obj_key": "test_obj_val"}
+    postgres_kvstore.put(test_key_2, test_blob_2)
+    blob = postgres_kvstore.get(test_key_2)
+    assert blob == test_blob_2
+
+    blob = postgres_kvstore.get_all()
+    assert len(blob) == 2
+
+    postgres_kvstore.delete(test_key_1)
+    postgres_kvstore.delete(test_key_2)
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+@pytest.mark.asyncio()
+async def test_kvstore_agetall(postgres_kvstore: PostgresKVStore) -> None:
+    test_key_1 = "test_key_1"
+    test_blob_1 = {"test_obj_key": "test_obj_val"}
+    await postgres_kvstore.aput(test_key_1, test_blob_1)
+    blob = await postgres_kvstore.aget(test_key_1)
+    assert blob == test_blob_1
+    test_key_2 = "test_key_2"
+    test_blob_2 = {"test_obj_key": "test_obj_val"}
+    await postgres_kvstore.aput(test_key_2, test_blob_2)
+    blob = await postgres_kvstore.aget(test_key_2)
+    assert blob == test_blob_2
+
+    blob = await postgres_kvstore.aget_all()
+    assert len(blob) == 2
+
+    await postgres_kvstore.adelete(test_key_1)
+    await postgres_kvstore.adelete(test_key_2)
+
+
+@pytest.mark.skipif(
+    no_packages, reason="asyncpg, psycopg2-binary and sqlalchemy not installed"
+)
+@pytest.mark.asyncio()
+async def test_kvstore_putall(postgres_kvstore: PostgresKVStore) -> None:
+    test_key = "test_key_putall_1"
+    test_blob = {"test_obj_key": "test_obj_val"}
+    test_key2 = "test_key_putall_2"
+    test_blob2 = {"test_obj_key2": "test_obj_val2"}
+    await postgres_kvstore.aput_all([(test_key, test_blob), (test_key2, test_blob2)])
+    blob = await postgres_kvstore.aget(test_key)
+    assert blob == test_blob
+    blob = await postgres_kvstore.aget(test_key2)
+    assert blob == test_blob2
+
+    await postgres_kvstore.adelete(test_key)
+    await postgres_kvstore.adelete(test_key2)
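A minimal usage sketch of the stores added above (not part of the diff; the connection parameters are illustrative placeholders and assume a reachable Postgres instance):

```python
from llama_index.schema import Document
from llama_index.storage.docstore.postgres_docstore import PostgresDocumentStore
from llama_index.storage.kvstore.postgres_kvstore import PostgresKVStore

# Illustrative connection parameters -- substitute your own deployment's values.
kvstore = PostgresKVStore.from_params(
    host="localhost",
    port="5432",
    database="testdb",
    user="testuser",
    password="testpassword",
    table_name="kvstore",
    use_jsonb=True,  # store values as JSONB rather than JSON
)

# Schema and table are created lazily on first use (perform_setup=True).
kvstore.put("key_1", {"field": "value"})
assert kvstore.get("key_1") == {"field": "value"}

# The docstore wraps the same KV layer under its own namespace.
docstore = PostgresDocumentStore(postgres_kvstore=kvstore)
docstore.add_documents([Document(text="hello postgres")])
```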