diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/README.md b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/README.md index c916ff8c6232b5454e8dfc895b0fb58d9c59ef9a..36737f6a3f4bd6f70c90f30c714536cf7ce1fe10 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/README.md +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/README.md @@ -27,6 +27,8 @@ vector_store = MariaDBVectorStore.from_params( database="vectordb", table_name="llama_index_vectorstore", embed_dim=1536, # OpenAI embedding dimension + default_m=6, # MariaDB Vector system parameter + ef_search=20, # MariaDB Vector system parameter ) ``` diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py index 5394ad1c9e1df6ad5e6597b7852476e76f7d95ec..c755aa9f491e43bd0b1fb99f8a6ddcd001ca221b 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, NamedTuple, Optional, Union from urllib.parse import quote_plus import sqlalchemy - from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.schema import BaseNode, MetadataMode from llama_index.core.vector_stores.types import ( @@ -52,6 +51,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): password="password", database="vectordb", table_name="llama_index_vectorstore", + default_m=6, + ef_search=20, embed_dim=1536 # OpenAI embedding dimension ) ``` @@ -65,6 +66,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): table_name: str schema_name: str embed_dim: int + default_m: int + ef_search: int perform_setup: bool debug: bool @@ -78,6 +81,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): table_name: str, schema_name: str, embed_dim: int = 1536, + default_m: int = 6, + ef_search: int = 20, perform_setup: bool = True, debug: bool = False, ) -> None: @@ -89,6 +94,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): table_name (str): Table name. schema_name (str): Schema name. embed_dim (int, optional): Embedding dimensions. Defaults to 1536. + default_m (int, optional): Default M value for the vector index. Defaults to 6. + ef_search (int, optional): EF search value for the vector index. Defaults to 20. perform_setup (bool, optional): If DB should be set up. Defaults to True. debug (bool, optional): Debug mode. Defaults to False. """ @@ -98,15 +105,20 @@ class MariaDBVectorStore(BasePydanticVectorStore): table_name=table_name, schema_name=schema_name, embed_dim=embed_dim, + default_m=default_m, + ef_search=ef_search, perform_setup=perform_setup, debug=debug, ) + self._initialize() + def close(self) -> None: if not self._is_initialized: return self._engine.dispose() + self._is_initialized = False @classmethod def class_name(cls) -> str: @@ -125,6 +137,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = None, connection_args: Optional[Dict[str, Any]] = None, embed_dim: int = 1536, + default_m: int = 6, + ef_search: int = 20, perform_setup: bool = True, debug: bool = False, ) -> "MariaDBVectorStore": @@ -141,6 +155,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): connection_string (Union[str, sqlalchemy.engine.URL]): Connection string to MariaDB DB. connection_args (Dict[str, Any], optional): A dictionary of connection options. embed_dim (int, optional): Embedding dimensions. Defaults to 1536. + default_m (int, optional): Default M value for the vector index. Defaults to 6. + ef_search (int, optional): EF search value for the vector index. Defaults to 20. perform_setup (bool, optional): If DB should be set up. Defaults to True. debug (bool, optional): Debug mode. Defaults to False. @@ -162,6 +178,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): table_name=table_name, schema_name=schema_name, embed_dim=embed_dim, + default_m=default_m, + ef_search=ef_search, perform_setup=perform_setup, debug=debug, ) @@ -200,8 +218,8 @@ class MariaDBVectorStore(BasePydanticVectorStore): text TEXT, metadata JSON, embedding VECTOR({self.embed_dim}) NOT NULL, - INDEX `{self.table_name}_node_id_idx` (`node_id`), - VECTOR INDEX (embedding) DISTANCE=cosine + INDEX (`node_id`), + VECTOR INDEX (embedding) M={self.default_m} DISTANCE=cosine ) """ connection.execute(sqlalchemy.text(stmt)) @@ -378,6 +396,7 @@ class MariaDBVectorStore(BasePydanticVectorStore): self._initialize() stmt = f""" + SET STATEMENT mhnsw_ef_search={self.ef_search} FOR SELECT node_id, text, @@ -435,6 +454,26 @@ class MariaDBVectorStore(BasePydanticVectorStore): connection.commit() + def count(self) -> int: + self._initialize() + + with self._engine.connect() as connection: + stmt = f"""SELECT COUNT(*) FROM `{self.table_name}`""" + result = connection.execute(sqlalchemy.text(stmt)) + + return result.scalar() or 0 + + def drop(self) -> None: + self._initialize() + + with self._engine.connect() as connection: + stmt = f"""DROP TABLE IF EXISTS `{self.table_name}`""" + connection.execute(sqlalchemy.text(stmt)) + + connection.commit() + + self.close() + def clear(self) -> None: self._initialize() diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml index 94443f379bba9420ac31ea72bd978a0651590c4c..a70c05a39eb6e6fcc1aaf0df8849f5712929974d 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml @@ -21,13 +21,13 @@ ignore_missing_imports = true python_version = "3.8" [tool.poetry] -authors = ["Your Name <you@example.com>"] +authors = ["Kalin Arsov <kalin@skysql.com>", "Vishal Rao <vishal@skysql.com>"] description = "llama-index vector_stores mariadb integration" exclude = ["**/BUILD"] license = "MIT" name = "llama-index-vector-stores-mariadb" readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = ">=3.9,<4.0" diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/BUILD b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/BUILD index dabf212d7e7162849c24a733909ac4f645d75a31..536d32cc95c39541b24aeefd454c36ca1aac65fd 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/BUILD +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/BUILD @@ -1 +1,3 @@ -python_tests() +python_tests( + dependencies=["llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb:poetry#pymysql"] +) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py index 4dc9f1e2b1a8e03ef9b046d8be61a20f1831e589..3e7f2a9f34af00deea866593f5ec1dbdeefc6360 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py @@ -4,7 +4,6 @@ from typing import Generator, List import pytest import sqlalchemy - from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode from llama_index.core.vector_stores.types import ( FilterCondition, @@ -13,6 +12,7 @@ from llama_index.core.vector_stores.types import ( MetadataFilters, VectorStoreQuery, ) + from llama_index.vector_stores.mariadb import MariaDBVectorStore from llama_index.vector_stores.mariadb.base import _meets_min_server_version @@ -49,18 +49,19 @@ TEST_NODES: List[TextNode] = [ ), ] -vector_store = MariaDBVectorStore.from_params( - database="test", - table_name="vector_store_test", - embed_dim=3, - host="127.0.0.1", - user="root", - password="test", - port="3306", -) - +vector_store = None try: + vector_store = MariaDBVectorStore.from_params( + database="test", + table_name="vector_store_test", + embed_dim=3, + host="127.0.0.1", + user="root", + password="test", + port="3306", + ) + # If you want to run the integration tests you need to do: # docker-compose up @@ -84,7 +85,8 @@ def teardown(request: pytest.FixtureRequest) -> Generator: if "noautousefixtures" in request.keywords: return - vector_store.clear() + if vector_store is not None: + vector_store.clear() @pytest.fixture(scope="session", autouse=True) @@ -95,7 +97,8 @@ def close_db_connection(request: pytest.FixtureRequest) -> Generator: if "noautousefixtures" in request.keywords: return - vector_store.close() + if vector_store is not None: + vector_store.close() @pytest.mark.parametrize( @@ -117,7 +120,7 @@ def test_meets_min_server_version(version: str, supported: bool) -> None: @pytest.mark.skipif( - run_integration_tests is False, + not run_integration_tests, reason="MariaDB instance required for integration tests", ) def test_query() -> None: @@ -131,7 +134,7 @@ def test_query() -> None: @pytest.mark.skipif( - run_integration_tests is False, + not run_integration_tests, reason="MariaDB instance required for integration tests", ) def test_query_with_metadatafilters() -> None: @@ -168,7 +171,7 @@ def test_query_with_metadatafilters() -> None: @pytest.mark.skipif( - run_integration_tests is False, + not run_integration_tests, reason="MariaDB instance required for integration tests", ) def test_delete() -> None: @@ -188,7 +191,7 @@ def test_delete() -> None: @pytest.mark.skipif( - run_integration_tests is False, + not run_integration_tests, reason="MariaDB instance required for integration tests", ) def test_delete_nodes() -> None: @@ -212,7 +215,26 @@ def test_delete_nodes() -> None: @pytest.mark.skipif( - run_integration_tests is False, + not run_integration_tests, + reason="MariaDB instance required for integration tests", +) +def test_count() -> None: + vector_store.add(TEST_NODES) + assert vector_store.count() == 3 + + +@pytest.mark.skipif( + not run_integration_tests, + reason="MariaDB instance required for integration tests", +) +def test_drop() -> None: + vector_store.add(TEST_NODES) + vector_store.drop() + assert vector_store.count() == 0 + + +@pytest.mark.skipif( + not run_integration_tests, reason="MariaDB instance required for integration tests", ) def test_clear() -> None: