diff --git a/.vscode/settings.json b/.vscode/settings.json index 0cee07a33df12823001844116fdd836dc1f587b1..598a079bcd4d20f990296e3c88becc5af13bc09a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,9 +2,12 @@ "python.formatting.provider": "none", "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true, + "source.organizeImports": true }, "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/CHANGELOG.md b/CHANGELOG.md index 51e2565242c379b02036e1c20dfd605c1e15c1e2..3ec01d1b2e2a57ec7c7da73598074ca4209766bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ### Bug Fixes / Nits - Normalize scores returned from ElasticSearch vector store (#7792) - Fixed `refresh_ref_docs()` bug with order of operations (#7664) +- Delay postgresql connection for `PGVectorStore` until actually needed (#7793) ## [0.8.33] - 2023-09-25 diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py index 3873afaa69f4efcde1c4fdd3c73c4c46bc0b6329..4b2721d497dd638ec05afc71d2420cbebb173432 100644 --- a/llama_index/vector_stores/postgres.py +++ b/llama_index/vector_stores/postgres.py @@ -104,6 +104,7 @@ class PGVectorStore(BasePydanticVectorStore): _session: Any = PrivateAttr() _async_engine: Any = PrivateAttr() _async_session: Any = PrivateAttr() + _is_initialized: bool = PrivateAttr(default=False) def __init__( self, @@ -157,11 +158,10 @@ class PGVectorStore(BasePydanticVectorStore): debug=debug, ) - self._connect() - self._create_extension() - self._create_tables_if_not_exists() - async def close(self) -> None: + if not self._is_initialized: + return None + self._session.close_all() self._engine.dispose() @@ -207,6 +207,8 @@ class PGVectorStore(BasePydanticVectorStore): @property def client(self) -> Any: + if not self._is_initialized: + return None return self._engine def _connect(self) -> Any: @@ -235,6 +237,13 @@ class PGVectorStore(BasePydanticVectorStore): session.execute(statement) session.commit() + def _initialize(self) -> None: + if not self._is_initialized: + self._connect() + self._create_extension() + self._create_tables_if_not_exists() + self._is_initialized = True + def _node_to_table_row(self, node: BaseNode) -> Any: return self._table_class( node_id=node.node_id, @@ -248,6 +257,7 @@ class PGVectorStore(BasePydanticVectorStore): ) def add(self, nodes: List[BaseNode]) -> List[str]: + self._initialize() ids = [] with self._session() as session: with session.begin(): @@ -259,6 +269,7 @@ class PGVectorStore(BasePydanticVectorStore): return ids async def async_add(self, nodes: List[BaseNode]) -> List[str]: + self._initialize() ids = [] async with self._async_session() as session: async with session.begin(): @@ -480,6 +491,7 @@ class PGVectorStore(BasePydanticVectorStore): async def aquery( self, query: VectorStoreQuery, **kwargs: Any ) -> VectorStoreQueryResult: + self._initialize() if query.mode == VectorStoreQueryMode.HYBRID: results = await self._async_hybrid_query(query) elif query.mode in [ @@ -500,6 +512,7 @@ class PGVectorStore(BasePydanticVectorStore): return self._db_rows_to_query_result(results) def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + self._initialize() if query.mode == VectorStoreQueryMode.HYBRID: results = self._hybrid_query(query) elif query.mode in [ @@ -522,6 +535,7 @@ class PGVectorStore(BasePydanticVectorStore): def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: import sqlalchemy + self._initialize() with self._session() as session: with session.begin(): stmt = sqlalchemy.text( diff --git a/tests/vector_stores/test_postgres.py b/tests/vector_stores/test_postgres.py index 9c231b721266ea47cde91f969619372b1257c325..ef5ef584941030102dc5cf7cddca471d31c4ebdb 100644 --- a/tests/vector_stores/test_postgres.py +++ b/tests/vector_stores/test_postgres.py @@ -159,6 +159,8 @@ async def test_instance_creation(db: None) -> None: table_name=TEST_TABLE_NAME, ) assert isinstance(pg, PGVectorStore) + assert not hasattr(pg, "_engine") + assert pg.client is None await pg.close() @@ -173,6 +175,7 @@ async def test_add_to_db_and_query( else: pg.add(node_embeddings) assert isinstance(pg, PGVectorStore) + assert hasattr(pg, "_engine") q = VectorStoreQuery(query_embedding=_get_sample_vector(1.0), similarity_top_k=1) if use_async: res = await pg.aquery(q) @@ -194,6 +197,7 @@ async def test_add_to_db_and_query_with_metadata_filters( else: pg.add(node_embeddings) assert isinstance(pg, PGVectorStore) + assert hasattr(pg, "_engine") filters = MetadataFilters( filters=[ExactMatchFilter(key="test_key", value="test_value")] ) @@ -220,6 +224,7 @@ async def test_add_to_db_query_and_delete( else: pg.add(node_embeddings) assert isinstance(pg, PGVectorStore) + assert hasattr(pg, "_engine") q = VectorStoreQuery(query_embedding=_get_sample_vector(0.1), similarity_top_k=1) @@ -243,6 +248,7 @@ async def test_save_load( else: pg.add(node_embeddings) assert isinstance(pg, PGVectorStore) + assert hasattr(pg, "_engine") q = VectorStoreQuery(query_embedding=_get_sample_vector(0.1), similarity_top_k=1) @@ -258,6 +264,7 @@ async def test_save_load( await pg.close() loaded_pg = cast(PGVectorStore, load_vector_store(pg_dict)) + assert not hasattr(loaded_pg, "_engine") loaded_pg_dict = loaded_pg.to_dict() for key, val in pg.to_dict().items(): assert loaded_pg_dict[key] == val @@ -266,6 +273,7 @@ async def test_save_load( res = await loaded_pg.aquery(q) else: res = loaded_pg.query(q) + assert hasattr(loaded_pg, "_engine") assert res.nodes assert len(res.nodes) == 1 assert res.nodes[0].node_id == "bbb" @@ -286,6 +294,7 @@ async def test_sparse_query( else: pg_hybrid.add(hybrid_node_embeddings) assert isinstance(pg_hybrid, PGVectorStore) + assert hasattr(pg_hybrid, "_engine") # text search should work when query is a sentence and not just a single word q = VectorStoreQuery( @@ -318,6 +327,7 @@ async def test_hybrid_query( else: pg_hybrid.add(hybrid_node_embeddings) assert isinstance(pg_hybrid, PGVectorStore) + assert hasattr(pg_hybrid, "_engine") q = VectorStoreQuery( query_embedding=_get_sample_vector(0.1), @@ -389,6 +399,7 @@ async def test_add_to_db_and_hybrid_query_with_metadata_filters( else: pg_hybrid.add(hybrid_node_embeddings) assert isinstance(pg_hybrid, PGVectorStore) + assert hasattr(pg_hybrid, "_engine") filters = MetadataFilters( filters=[ExactMatchFilter(key="test_key", value="test_value")] )