From 369a2942df2efcf6b74461c45d20a0af1fbe4ae2 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi <mpippi@gmail.com> Date: Thu, 27 Feb 2025 22:13:51 +0100 Subject: [PATCH] fix: escape params in SQL queries in DuckDB vector store (#17952) --- .../llama_index/vector_stores/duckdb/base.py | 103 ++++++++++-------- .../pyproject.toml | 4 +- 2 files changed, 58 insertions(+), 49 deletions(-) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py index 5ce76c34cf..9bbc8cbcd3 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/llama_index/vector_stores/duckdb/base.py @@ -1,13 +1,16 @@ """DuckDB vector store.""" -import logging import json -from typing import Any, List, Optional +import logging import os +from typing import Any, List, Optional, cast + +from fsspec.utils import Sequence from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.schema import BaseNode, MetadataMode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, + MetadataFilter, MetadataFilters, VectorStoreQuery, VectorStoreQueryResult, @@ -17,6 +20,8 @@ from llama_index.core.vector_stores.utils import ( node_to_metadata_dict, ) +import duckdb + logger = logging.getLogger(__name__) import_err_msg = "`duckdb` package not found, please run `pip install duckdb`" @@ -28,11 +33,6 @@ class DuckDBLocalContext: self._home_dir = os.path.expanduser("~") def __enter__(self) -> "duckdb.DuckDBPyConnection": - try: - import duckdb - except ImportError: - raise ImportError(import_err_msg) - if not os.path.exists(os.path.dirname(self.database_path)): raise ValueError( f"Directory {os.path.dirname(self.database_path)} does not exist." @@ -51,14 +51,13 @@ class DuckDBLocalContext: return self._conn def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self._conn.close() - if self._conn: self._conn.close() class DuckDBVectorStore(BasePydanticVectorStore): - """DuckDB vector store. + """ + DuckDB vector store. In this vector store, embeddings are stored within a DuckDB database. @@ -83,7 +82,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): flat_metadata: bool = True database_name: Optional[str] - table_name: Optional[str] + table_name: str # schema_name: Optional[str] # TODO: support schema name embed_dim: Optional[int] # hybrid_search: Optional[bool] # TODO: support hybrid search @@ -96,11 +95,9 @@ class DuckDBVectorStore(BasePydanticVectorStore): def __init__( self, - database_name: Optional[str] = ":memory:", - table_name: Optional[str] = "documents", - # schema_name: Optional[str] = "main", + database_name: str = ":memory:", + table_name: str = "documents", embed_dim: Optional[int] = None, - # hybrid_search: Optional[bool] = False, # https://duckdb.org/docs/extensions/full_text_search text_search_config: Optional[dict] = { "stemmer": "english", @@ -110,7 +107,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): "lower": True, "overwrite": False, }, - persist_dir: Optional[str] = "./storage", + persist_dir: str = "./storage", **kwargs: Any, ) -> None: """Init params.""" @@ -140,15 +137,14 @@ class DuckDBVectorStore(BasePydanticVectorStore): conn = None - super().__init__( - database_name=database_name, - table_name=table_name, - # schema_name=schema_name, - embed_dim=embed_dim, - # hybrid_search=hybrid_search, - text_search_config=text_search_config, - persist_dir=persist_dir, - ) + fields = { + "database_name": database_name, + "table_name": table_name, + "embed_dim": embed_dim, + "text_search_config": text_search_config, + "persist_dir": persist_dir, + } + super().__init__(stores_text=True, **fields) self._is_initialized = False self._conn = conn self._database_path = database_path @@ -157,7 +153,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): def from_local( cls, database_path: str, - table_name: Optional[str] = "documents", + table_name: str = "documents", # schema_name: Optional[str] = "main", embed_dim: Optional[int] = None, # hybrid_search: Optional[bool] = False, @@ -201,8 +197,8 @@ class DuckDBVectorStore(BasePydanticVectorStore): @classmethod def from_params( cls, - database_name: Optional[str] = ":memory:", - table_name: Optional[str] = "documents", + database_name: str = ":memory:", + table_name: str = "documents", # schema_name: Optional[str] = "main", embed_dim: Optional[int] = None, # hybrid_search: Optional[bool] = False, @@ -214,7 +210,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): "lower": True, "overwrite": False, }, - persist_dir: Optional[str] = "./storage", + persist_dir: str = "./storage", **kwargs: Any, ) -> "DuckDBVectorStore": return cls( @@ -263,7 +259,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): if self.database_name == ":memory:": self._conn.execute(_query) - else: + elif self._database_path is not None: with DuckDBLocalContext(self._database_path) as _conn: _conn.execute(_query) @@ -284,8 +280,9 @@ class DuckDBVectorStore(BasePydanticVectorStore): def _table_row_to_node(self, row: Any) -> BaseNode: return metadata_dict_to_node(json.loads(row[3]), row[1]) - def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: - """Add nodes to index. + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + """ + Add nodes to index. Args: nodes: List[BaseNode]: list of nodes with embeddings @@ -301,7 +298,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): ids.append(node.node_id) _row = self._node_to_table_row(node) _table.insert(_row) - else: + elif self._database_path is not None: with DuckDBLocalContext(self._database_path) as _conn: _table = _conn.table(self.table_name) for node in nodes: @@ -321,18 +318,18 @@ class DuckDBVectorStore(BasePydanticVectorStore): """ _ddb_query = f""" DELETE FROM {self.table_name} - WHERE json_extract_string(metadata_, '$.ref_doc_id') = '{ref_doc_id}'; + WHERE json_extract_string(metadata_, '$.ref_doc_id') = ?; """ if self.database_name == ":memory:": - self._conn.execute(_ddb_query) - else: + self._conn.execute(_ddb_query, [ref_doc_id]) + elif self._database_path is not None: with DuckDBLocalContext(self._database_path) as _conn: - _conn.execute(_ddb_query) + _conn.execute(_ddb_query, [ref_doc_id]) @staticmethod def _build_metadata_filter_condition( standard_filters: MetadataFilters, - ) -> dict: + ) -> str: """Translate standard metadata filters to DuckDB SQL specification.""" filters_list = [] # condition = standard_filters.condition or "and" ## and/or as strings. @@ -340,6 +337,7 @@ class DuckDBVectorStore(BasePydanticVectorStore): _filters_condition_list = [] for filter in standard_filters.filters: + filter = cast(MetadataFilter, filter) if filter.operator: if filter.operator in [ "<", @@ -372,7 +370,8 @@ class DuckDBVectorStore(BasePydanticVectorStore): return f" {condition} ".join(_filters_condition_list) def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Query index for top k most similar nodes. + """ + Query index for top k most similar nodes. Args: query.query_embedding (List[float]): query embedding @@ -389,29 +388,39 @@ class DuckDBVectorStore(BasePydanticVectorStore): _ddb_query = f""" SELECT node_id, text, embedding, metadata_, score FROM ( - SELECT *, list_cosine_similarity(embedding, {query.query_embedding}) AS score + SELECT *, list_cosine_similarity(embedding, ?) AS score FROM {self.table_name} - WHERE {_filter_string} + WHERE ? ) sq WHERE score IS NOT NULL - ORDER BY score DESC LIMIT {query.similarity_top_k}; + ORDER BY score DESC LIMIT ?; """ + query_params = [ + query.query_embedding, + _filter_string, + query.similarity_top_k, + ] else: _ddb_query = f""" SELECT node_id, text, embedding, metadata_, score FROM ( - SELECT *, list_cosine_similarity(embedding, {query.query_embedding}) AS score + SELECT *, list_cosine_similarity(embedding, ?) AS score FROM {self.table_name} ) sq WHERE score IS NOT NULL - ORDER BY score DESC LIMIT {query.similarity_top_k}; + ORDER BY score DESC LIMIT ?; """ + query_params = [ + query.query_embedding, + query.similarity_top_k, + ] + _final_results = [] if self.database_name == ":memory:": - _final_results = self._conn.execute(_ddb_query).fetchall() - else: + _final_results = self._conn.execute(_ddb_query, query_params).fetchall() + elif self._database_path is not None: with DuckDBLocalContext(self._database_path) as _conn: - _final_results = _conn.execute(_ddb_query).fetchall() + _final_results = _conn.execute(_ddb_query, query_params).fetchall() for _row in _final_results: node = self._table_row_to_node(_row) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml index efaeccdb8d..481b546e47 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-duckdb/pyproject.toml @@ -18,7 +18,7 @@ DuckDBVectorStore = "krish-adi" disallow_untyped_defs = true exclude = ["_static", "build", "examples", "notebooks", "venv"] ignore_missing_imports = true -python_version = "3.8" +python_version = "3.9" [tool.poetry] authors = ["Adithya Krishnan <me@krishadi.com>"] @@ -28,7 +28,7 @@ license = "MIT" maintainers = ["krish-adi"] name = "llama-index-vector-stores-duckdb" readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = ">=3.9,<4.0" -- GitLab