From 1f943903f1515b61774444472290144b106bff1c Mon Sep 17 00:00:00 2001 From: Pavan Ramkumar <pavan.ramkumar@system1.bio> Date: Fri, 2 Feb 2024 10:06:11 -0800 Subject: [PATCH] LanceDBVectorStore bug fixes (#10404) --- llama_index/vector_stores/lancedb.py | 34 ++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/llama_index/vector_stores/lancedb.py b/llama_index/vector_stores/lancedb.py index 288b7d138..087a9ff63 100644 --- a/llama_index/vector_stores/lancedb.py +++ b/llama_index/vector_stores/lancedb.py @@ -19,6 +19,7 @@ from llama_index.vector_stores.types import ( VectorStoreQueryResult, ) from llama_index.vector_stores.utils import ( + DEFAULT_DOC_ID_KEY, DEFAULT_TEXT_KEY, legacy_metadata_dict_to_node, metadata_dict_to_node, @@ -52,7 +53,8 @@ def _to_llama_similarities(results: DataFrame) -> List[float]: class LanceDBVectorStore(VectorStore): - """The LanceDB Vector Store. + """ + The LanceDB Vector Store. Stores text and embeddings in LanceDB. The vector store will open an existing LanceDB dataset or create the dataset if it does not exist. @@ -61,6 +63,8 @@ class LanceDBVectorStore(VectorStore): uri (str, required): Location where LanceDB will store its files. table_name (str, optional): The table name where the embeddings will be stored. Defaults to "vectors". + vector_column_name (str, optional): The vector column name in the table if different from default. + Defaults to "vector", in keeping with lancedb convention. nprobes (int, optional): The number of probes used. A higher number makes search more accurate but also slower. Defaults to 20. @@ -83,9 +87,11 @@ class LanceDBVectorStore(VectorStore): self, uri: str, table_name: str = "vectors", + vector_column_name: str = "vector", nprobes: int = 20, refine_factor: Optional[int] = None, text_key: str = DEFAULT_TEXT_KEY, + doc_id_key: str = DEFAULT_DOC_ID_KEY, **kwargs: Any, ) -> None: """Init params.""" @@ -98,8 +104,10 @@ class LanceDBVectorStore(VectorStore): self.connection = lancedb.connect(uri) self.uri = uri self.table_name = table_name + self.vector_column_name = vector_column_name self.nprobes = nprobes self.text_key = text_key + self.doc_id_key = doc_id_key self.refine_factor = refine_factor @property @@ -165,7 +173,10 @@ class LanceDBVectorStore(VectorStore): table = self.connection.open_table(self.table_name) lance_query = ( - table.search(query.query_embedding) + table.search( + query=query.query_embedding, + vector_column_name=self.vector_column_name, + ) .limit(query.similarity_top_k) .where(where) .nprobes(self.nprobes) @@ -174,28 +185,33 @@ class LanceDBVectorStore(VectorStore): if self.refine_factor is not None: lance_query.refine_factor(self.refine_factor) - results = lance_query.to_df() + results = lance_query.to_pandas() nodes = [] for _, item in results.iterrows(): try: node = metadata_dict_to_node(item.metadata) - node.embedding = list(item.vector) + node.embedding = list(item[self.vector_column_name]) except Exception: # deprecated legacy logic for backward compatibility _logger.debug( "Failed to parse Node metadata, fallback to legacy logic." ) - metadata, node_info, _relation = legacy_metadata_dict_to_node( - item.metadata, text_key=self.text_key - ) + if "metadata" in item: + metadata, node_info, _relation = legacy_metadata_dict_to_node( + item.metadata, text_key=self.text_key + ) + else: + metadata, node_info = {}, {} node = TextNode( - text=item.text or "", + text=item[self.text_key] or "", id_=item.id, metadata=metadata, start_char_idx=node_info.get("start", None), end_char_idx=node_info.get("end", None), relationships={ - NodeRelationship.SOURCE: RelatedNodeInfo(node_id=item.doc_id), + NodeRelationship.SOURCE: RelatedNodeInfo( + node_id=item[self.doc_id_key] + ), }, ) -- GitLab