From 9d75ab49d59cc528030636105dd08fb66d2b3269 Mon Sep 17 00:00:00 2001
From: Adithya Krishnan <krishsandeep@gmail.com>
Date: Thu, 11 Apr 2024 03:07:42 +0200
Subject: [PATCH] Update embedding field to use fixed array size (#12416)

---
 .../llama_index/vector_stores/duckdb/base.py  | 44 +++++++++----------
 .../pyproject.toml                            |  4 +-
 .../tests/test_duckdb.py                      |  2 +-
 3 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py
index 75822b276..790ac9483 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py
@@ -98,7 +98,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
         database_name: Optional[str] = ":memory:",
         table_name: Optional[str] = "documents",
         # schema_name: Optional[str] = "main",
-        embed_dim: Optional[int] = 1536,
+        embed_dim: Optional[int] = None,
         # hybrid_search: Optional[bool] = False,
         # https://duckdb.org/docs/extensions/full_text_search
         text_search_config: Optional[dict] = {
@@ -161,13 +161,9 @@ class DuckDBVectorStore(BasePydanticVectorStore):
             except Exception as e:
                 raise ValueError(f"Index table {table_name} not found in the database.")
 
-            _std = {
-                "text": "VARCHAR",
-                "node_id": "VARCHAR",
-                "embedding": "FLOAT[]",
-                "metadata_": "JSON",
-            }
-            _ti = {_i[0]: _i[1] for _i in _table_info}
+            # Not testing for the column type similarity only testing for the column names.
+            _std = {"text", "node_id", "embedding", "metadata_"}
+            _ti = {_i[0] for _i in _table_info}
             if _std != _ti:
                 raise ValueError(
                     f"Index table {table_name} does not have the correct schema."
@@ -188,7 +184,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
         database_name: Optional[str] = ":memory:",
         table_name: Optional[str] = "documents",
         # schema_name: Optional[str] = "main",
-        embed_dim: Optional[int] = 1536,
+        embed_dim: Optional[int] = None,
         # hybrid_search: Optional[bool] = False,
         text_search_config: Optional[dict] = {
             "stemmer": "english",
@@ -226,9 +222,17 @@ class DuckDBVectorStore(BasePydanticVectorStore):
             # TODO: schema.table also.
             # Check if table and type is present
             # if not, create table
-            if self.database_name == ":memory:":
-                self._conn.execute(
-                    f"""
+            if self.embed_dim is None:
+                _query = f"""
+                    CREATE TABLE {self.table_name} (
+                        node_id VARCHAR,
+                        text TEXT,
+                        embedding FLOAT[],
+                        metadata_ JSON
+                        );
+                    """
+            else:
+                _query = f"""
                     CREATE TABLE {self.table_name} (
                         node_id VARCHAR,
                         text TEXT,
@@ -236,19 +240,13 @@ class DuckDBVectorStore(BasePydanticVectorStore):
                         metadata_ JSON
                         );
                     """
-                )
+
+            if self.database_name == ":memory:":
+                self._conn.execute(_query)
             else:
                 with DuckDBLocalContext(self._database_path) as _conn:
-                    _conn.execute(
-                        f"""
-                        CREATE TABLE {self.table_name} (
-                            node_id VARCHAR,
-                            text TEXT,
-                            embedding FLOAT[{self.embed_dim}],
-                            metadata_ JSON
-                            );
-                        """
-                    )
+                    _conn.execute(_query)
+
             self._is_initialized = True
 
     def _node_to_table_row(self, node: BaseNode) -> Any:
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml
index 2116667f7..a41553cf0 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml
@@ -28,12 +28,12 @@ license = "MIT"
 maintainers = ["krish-adi"]
 name = "llama-index-vector-stores-duckdb"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.0"
-duckdb = "0.9.2"
+duckdb = "^0.10.1"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/tests/test_duckdb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/tests/test_duckdb.py
index c9e226bdc..7f295bd8d 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/tests/test_duckdb.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/tests/test_duckdb.py
@@ -82,7 +82,7 @@ def text_node_list() -> List[TextNode]:
 
 @pytest.fixture(scope="module")
 def vector_store() -> DuckDBVectorStore:
-    return DuckDBVectorStore()
+    return DuckDBVectorStore(embed_dim=3)
 
 
 def test_instance_creation_from_memory(
-- 
GitLab