From 1f943903f1515b61774444472290144b106bff1c Mon Sep 17 00:00:00 2001
From: Pavan Ramkumar <pavan.ramkumar@system1.bio>
Date: Fri, 2 Feb 2024 10:06:11 -0800
Subject: [PATCH] LanceDBVectorStore bug fixes (#10404)

---
 llama_index/vector_stores/lancedb.py | 34 ++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 9 deletions(-)

diff --git a/llama_index/vector_stores/lancedb.py b/llama_index/vector_stores/lancedb.py
index 288b7d138..087a9ff63 100644
--- a/llama_index/vector_stores/lancedb.py
+++ b/llama_index/vector_stores/lancedb.py
@@ -19,6 +19,7 @@ from llama_index.vector_stores.types import (
     VectorStoreQueryResult,
 )
 from llama_index.vector_stores.utils import (
+    DEFAULT_DOC_ID_KEY,
     DEFAULT_TEXT_KEY,
     legacy_metadata_dict_to_node,
     metadata_dict_to_node,
@@ -52,7 +53,8 @@ def _to_llama_similarities(results: DataFrame) -> List[float]:
 
 
 class LanceDBVectorStore(VectorStore):
-    """The LanceDB Vector Store.
+    """
+    The LanceDB Vector Store.
 
     Stores text and embeddings in LanceDB. The vector store will open an existing
         LanceDB dataset or create the dataset if it does not exist.
@@ -61,6 +63,8 @@ class LanceDBVectorStore(VectorStore):
         uri (str, required): Location where LanceDB will store its files.
         table_name (str, optional): The table name where the embeddings will be stored.
             Defaults to "vectors".
+        vector_column_name (str, optional): The vector column name in the table if different from default.
+            Defaults to "vector", in keeping with lancedb convention.
         nprobes (int, optional): The number of probes used.
             A higher number makes search more accurate but also slower.
             Defaults to 20.
@@ -83,9 +87,11 @@ class LanceDBVectorStore(VectorStore):
         self,
         uri: str,
         table_name: str = "vectors",
+        vector_column_name: str = "vector",
         nprobes: int = 20,
         refine_factor: Optional[int] = None,
         text_key: str = DEFAULT_TEXT_KEY,
+        doc_id_key: str = DEFAULT_DOC_ID_KEY,
         **kwargs: Any,
     ) -> None:
         """Init params."""
@@ -98,8 +104,10 @@ class LanceDBVectorStore(VectorStore):
         self.connection = lancedb.connect(uri)
         self.uri = uri
         self.table_name = table_name
+        self.vector_column_name = vector_column_name
         self.nprobes = nprobes
         self.text_key = text_key
+        self.doc_id_key = doc_id_key
         self.refine_factor = refine_factor
 
     @property
@@ -165,7 +173,10 @@ class LanceDBVectorStore(VectorStore):
 
         table = self.connection.open_table(self.table_name)
         lance_query = (
-            table.search(query.query_embedding)
+            table.search(
+                query=query.query_embedding,
+                vector_column_name=self.vector_column_name,
+            )
             .limit(query.similarity_top_k)
             .where(where)
             .nprobes(self.nprobes)
@@ -174,28 +185,33 @@ class LanceDBVectorStore(VectorStore):
         if self.refine_factor is not None:
             lance_query.refine_factor(self.refine_factor)
 
-        results = lance_query.to_df()
+        results = lance_query.to_pandas()
         nodes = []
         for _, item in results.iterrows():
             try:
                 node = metadata_dict_to_node(item.metadata)
-                node.embedding = list(item.vector)
+                node.embedding = list(item[self.vector_column_name])
             except Exception:
                 # deprecated legacy logic for backward compatibility
                 _logger.debug(
                     "Failed to parse Node metadata, fallback to legacy logic."
                 )
-                metadata, node_info, _relation = legacy_metadata_dict_to_node(
-                    item.metadata, text_key=self.text_key
-                )
+                if "metadata" in item:
+                    metadata, node_info, _relation = legacy_metadata_dict_to_node(
+                        item.metadata, text_key=self.text_key
+                    )
+                else:
+                    metadata, node_info = {}, {}
                 node = TextNode(
-                    text=item.text or "",
+                    text=item[self.text_key] or "",
                     id_=item.id,
                     metadata=metadata,
                     start_char_idx=node_info.get("start", None),
                     end_char_idx=node_info.get("end", None),
                     relationships={
-                        NodeRelationship.SOURCE: RelatedNodeInfo(node_id=item.doc_id),
+                        NodeRelationship.SOURCE: RelatedNodeInfo(
+                            node_id=item[self.doc_id_key]
+                        ),
                     },
                 )
 
-- 
GitLab