From 369a2942df2efcf6b74461c45d20a0af1fbe4ae2 Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi <mpippi@gmail.com>
Date: Thu, 27 Feb 2025 22:13:51 +0100
Subject: [PATCH] fix: escape params in SQL queries in DuckDB vector store
 (#17952)

---
 .../llama_index/vector_stores/duckdb/base.py  | 103 ++++++++++--------
 .../pyproject.toml                            |   4 +-
 2 files changed, 58 insertions(+), 49 deletions(-)

diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py
index 5ce76c34cf..9bbc8cbcd3 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py
@@ -1,13 +1,16 @@
 """DuckDB vector store."""
 
-import logging
 import json
-from typing import Any, List, Optional
+import logging
 import os
+from collections.abc import Sequence
+from typing import Any, List, Optional, cast
+
 from llama_index.core.bridge.pydantic import PrivateAttr
 from llama_index.core.schema import BaseNode, MetadataMode
 from llama_index.core.vector_stores.types import (
     BasePydanticVectorStore,
+    MetadataFilter,
     MetadataFilters,
     VectorStoreQuery,
     VectorStoreQueryResult,
@@ -17,6 +20,8 @@ from llama_index.core.vector_stores.utils import (
     node_to_metadata_dict,
 )
 
+import duckdb
+
 logger = logging.getLogger(__name__)
 import_err_msg = "`duckdb` package not found, please run `pip install duckdb`"
 
@@ -28,11 +33,6 @@ class DuckDBLocalContext:
         self._home_dir = os.path.expanduser("~")
 
     def __enter__(self) -> "duckdb.DuckDBPyConnection":
-        try:
-            import duckdb
-        except ImportError:
-            raise ImportError(import_err_msg)
-
         if not os.path.exists(os.path.dirname(self.database_path)):
             raise ValueError(
                 f"Directory {os.path.dirname(self.database_path)} does not exist."
@@ -51,14 +51,13 @@ class DuckDBLocalContext:
         return self._conn
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        self._conn.close()
-
         if self._conn:
             self._conn.close()
 
 
 class DuckDBVectorStore(BasePydanticVectorStore):
-    """DuckDB vector store.
+    """
+    DuckDB vector store.
 
     In this vector store, embeddings are stored within a DuckDB database.
 
@@ -83,7 +82,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
     flat_metadata: bool = True
 
     database_name: Optional[str]
-    table_name: Optional[str]
+    table_name: str
     # schema_name: Optional[str] # TODO: support schema name
     embed_dim: Optional[int]
     # hybrid_search: Optional[bool] # TODO: support hybrid search
@@ -96,11 +95,9 @@ class DuckDBVectorStore(BasePydanticVectorStore):
 
     def __init__(
         self,
-        database_name: Optional[str] = ":memory:",
-        table_name: Optional[str] = "documents",
-        # schema_name: Optional[str] = "main",
+        database_name: str = ":memory:",
+        table_name: str = "documents",
         embed_dim: Optional[int] = None,
-        # hybrid_search: Optional[bool] = False,
         # https://duckdb.org/docs/extensions/full_text_search
         text_search_config: Optional[dict] = {
             "stemmer": "english",
@@ -110,7 +107,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
             "lower": True,
             "overwrite": False,
         },
-        persist_dir: Optional[str] = "./storage",
+        persist_dir: str = "./storage",
         **kwargs: Any,
     ) -> None:
         """Init params."""
@@ -140,15 +137,14 @@ class DuckDBVectorStore(BasePydanticVectorStore):
 
             conn = None
 
-        super().__init__(
-            database_name=database_name,
-            table_name=table_name,
-            # schema_name=schema_name,
-            embed_dim=embed_dim,
-            # hybrid_search=hybrid_search,
-            text_search_config=text_search_config,
-            persist_dir=persist_dir,
-        )
+        fields = {
+            "database_name": database_name,
+            "table_name": table_name,
+            "embed_dim": embed_dim,
+            "text_search_config": text_search_config,
+            "persist_dir": persist_dir,
+        }
+        super().__init__(stores_text=True, **fields)
         self._is_initialized = False
         self._conn = conn
         self._database_path = database_path
@@ -157,7 +153,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
     def from_local(
         cls,
         database_path: str,
-        table_name: Optional[str] = "documents",
+        table_name: str = "documents",
         # schema_name: Optional[str] = "main",
         embed_dim: Optional[int] = None,
         # hybrid_search: Optional[bool] = False,
@@ -201,8 +197,8 @@ class DuckDBVectorStore(BasePydanticVectorStore):
     @classmethod
     def from_params(
         cls,
-        database_name: Optional[str] = ":memory:",
-        table_name: Optional[str] = "documents",
+        database_name: str = ":memory:",
+        table_name: str = "documents",
         # schema_name: Optional[str] = "main",
         embed_dim: Optional[int] = None,
         # hybrid_search: Optional[bool] = False,
@@ -214,7 +210,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
             "lower": True,
             "overwrite": False,
         },
-        persist_dir: Optional[str] = "./storage",
+        persist_dir: str = "./storage",
         **kwargs: Any,
     ) -> "DuckDBVectorStore":
         return cls(
@@ -263,7 +259,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
 
             if self.database_name == ":memory:":
                 self._conn.execute(_query)
-            else:
+            elif self._database_path is not None:
                 with DuckDBLocalContext(self._database_path) as _conn:
                     _conn.execute(_query)
 
@@ -284,8 +280,9 @@ class DuckDBVectorStore(BasePydanticVectorStore):
     def _table_row_to_node(self, row: Any) -> BaseNode:
         return metadata_dict_to_node(json.loads(row[3]), row[1])
 
-    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
-        """Add nodes to index.
+    def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]:
+        """
+        Add nodes to index.
 
         Args:
             nodes: List[BaseNode]: list of nodes with embeddings
@@ -301,7 +298,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
                 ids.append(node.node_id)
                 _row = self._node_to_table_row(node)
                 _table.insert(_row)
-        else:
+        elif self._database_path is not None:
             with DuckDBLocalContext(self._database_path) as _conn:
                 _table = _conn.table(self.table_name)
                 for node in nodes:
@@ -321,18 +318,18 @@ class DuckDBVectorStore(BasePydanticVectorStore):
         """
         _ddb_query = f"""
             DELETE FROM {self.table_name}
-            WHERE json_extract_string(metadata_, '$.ref_doc_id') = '{ref_doc_id}';
+            WHERE json_extract_string(metadata_, '$.ref_doc_id') = ?;
             """
         if self.database_name == ":memory:":
-            self._conn.execute(_ddb_query)
-        else:
+            self._conn.execute(_ddb_query, [ref_doc_id])
+        elif self._database_path is not None:
             with DuckDBLocalContext(self._database_path) as _conn:
-                _conn.execute(_ddb_query)
+                _conn.execute(_ddb_query, [ref_doc_id])
 
     @staticmethod
     def _build_metadata_filter_condition(
         standard_filters: MetadataFilters,
-    ) -> dict:
+    ) -> str:
         """Translate standard metadata filters to DuckDB SQL specification."""
         filters_list = []
         # condition = standard_filters.condition or "and"  ## and/or as strings.
@@ -340,6 +337,7 @@ class DuckDBVectorStore(BasePydanticVectorStore):
         _filters_condition_list = []
 
         for filter in standard_filters.filters:
+            filter = cast(MetadataFilter, filter)
             if filter.operator:
                 if filter.operator in [
                     "<",
@@ -372,7 +370,8 @@ class DuckDBVectorStore(BasePydanticVectorStore):
         return f" {condition} ".join(_filters_condition_list)
 
     def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
-        """Query index for top k most similar nodes.
+        """
+        Query index for top k most similar nodes.
 
         Args:
             query.query_embedding (List[float]): query embedding
@@ -389,29 +388,39 @@ class DuckDBVectorStore(BasePydanticVectorStore):
             _ddb_query = f"""
             SELECT node_id, text, embedding, metadata_, score
             FROM (
-                SELECT *, list_cosine_similarity(embedding, {query.query_embedding}) AS score
+                SELECT *, list_cosine_similarity(embedding, ?) AS score
                 FROM {self.table_name}
-                WHERE {_filter_string}
+                WHERE ({_filter_string})
             ) sq
             WHERE score IS NOT NULL
-            ORDER BY score DESC LIMIT {query.similarity_top_k};
+            ORDER BY score DESC LIMIT ?;
             """
+            # The filter is an SQL expression, so it cannot be bound as "?".
+            query_params = [
+                query.query_embedding,
+                query.similarity_top_k,
+            ]
         else:
             _ddb_query = f"""
             SELECT node_id, text, embedding, metadata_, score
             FROM (
-                SELECT *, list_cosine_similarity(embedding, {query.query_embedding}) AS score
+                SELECT *, list_cosine_similarity(embedding, ?) AS score
                 FROM {self.table_name}
             ) sq
             WHERE score IS NOT NULL
-            ORDER BY score DESC LIMIT {query.similarity_top_k};
+            ORDER BY score DESC LIMIT ?;
             """
+            query_params = [
+                query.query_embedding,
+                query.similarity_top_k,
+            ]
 
+        _final_results = []
         if self.database_name == ":memory:":
-            _final_results = self._conn.execute(_ddb_query).fetchall()
-        else:
+            _final_results = self._conn.execute(_ddb_query, query_params).fetchall()
+        elif self._database_path is not None:
             with DuckDBLocalContext(self._database_path) as _conn:
-                _final_results = _conn.execute(_ddb_query).fetchall()
+                _final_results = _conn.execute(_ddb_query, query_params).fetchall()
 
         for _row in _final_results:
             node = self._table_row_to_node(_row)
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml
index efaeccdb8d..481b546e47 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml
@@ -18,7 +18,7 @@ DuckDBVectorStore = "krish-adi"
 disallow_untyped_defs = true
 exclude = ["_static", "build", "examples", "notebooks", "venv"]
 ignore_missing_imports = true
-python_version = "3.8"
+python_version = "3.9"
 
 [tool.poetry]
 authors = ["Adithya Krishnan <me@krishadi.com>"]
@@ -28,7 +28,7 @@ license = "MIT"
 maintainers = ["krish-adi"]
 name = "llama-index-vector-stores-duckdb"
 readme = "README.md"
-version = "0.3.0"
+version = "0.3.1"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-- 
GitLab