From 90fe4d11b2b8389974c757bea36e7cb9c9af5505 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Mon, 6 Jan 2025 09:50:24 +0400
Subject: [PATCH] feat: modify index readiness checks

---
 semantic_router/index/base.py         | 12 +++++++++---
 semantic_router/index/hybrid_local.py |  7 -------
 semantic_router/index/local.py        | 14 +++++++-------
 semantic_router/index/pinecone.py     | 20 ++++++++++++--------
 semantic_router/index/postgres.py     | 26 +++++++++++++++-----------
 semantic_router/index/qdrant.py       | 14 +++++++-------
 semantic_router/routers/base.py       | 21 +++++++++++++++++++--
 7 files changed, 69 insertions(+), 45 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 0391e3fd..452a18c6 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -15,6 +15,12 @@ from semantic_router.utils.logger import logger
 RETRY_WAIT_TIME = 2.5
 
 
+class IndexConfig(BaseModel):
+    type: str
+    dimensions: int
+    vectors: int
+
+
 class BaseIndex(BaseModel):
     """
     Base class for indices using Pydantic's BaseModel.
@@ -146,10 +152,10 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-    def describe(self) -> Dict:
+    def describe(self) -> IndexConfig:
         """
-        Returns a dictionary with index details such as type, dimensions, and total
-        vector count.
+        Returns an IndexConfig object with index details such as type, dimensions, and
+        total vector count.
         This method should be implemented by subclasses.
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
diff --git a/semantic_router/index/hybrid_local.py b/semantic_router/index/hybrid_local.py
index 4175eac9..cab9b982 100644
--- a/semantic_router/index/hybrid_local.py
+++ b/semantic_router/index/hybrid_local.py
@@ -67,13 +67,6 @@ class HybridLocalIndex(LocalIndex):
             return []
         return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
 
-    def describe(self) -> Dict:
-        return {
-            "type": self.type,
-            "dimensions": self.index.shape[1] if self.index is not None else 0,
-            "vectors": self.index.shape[0] if self.index is not None else 0,
-        }
-
     def _sparse_dot_product(
         self, vec_a: dict[int, float], vec_b: dict[int, float]
     ) -> float:
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index c4f14fc4..10b77bea 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Dict
 import numpy as np
 
 from semantic_router.schema import ConfigParameter, SparseEmbedding, Utterance
-from semantic_router.index.base import BaseIndex
+from semantic_router.index.base import BaseIndex, IndexConfig
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.utils.logger import logger
 from typing import Any
@@ -75,12 +75,12 @@ class LocalIndex(BaseIndex):
             return []
         return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
 
-    def describe(self) -> Dict:
-        return {
-            "type": self.type,
-            "dimensions": self.index.shape[1] if self.index is not None else 0,
-            "vectors": self.index.shape[0] if self.index is not None else 0,
-        }
+    def describe(self) -> IndexConfig:
+        return IndexConfig(
+            type=self.type,
+            dimensions=self.index.shape[1] if self.index is not None else 0,
+            vectors=self.index.shape[0] if self.index is not None else 0,
+        )
 
     def query(
         self,
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 17eabddd..b0706318 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Union, Tuple
 import numpy as np
 from pydantic import BaseModel, Field
 
-from semantic_router.index.base import BaseIndex
+from semantic_router.index.base import BaseIndex, IndexConfig
 from semantic_router.schema import ConfigParameter, SparseEmbedding
 from semantic_router.utils.logger import logger
 
@@ -449,16 +449,20 @@ class PineconeIndex(BaseIndex):
     def delete_all(self):
         self.index.delete(delete_all=True, namespace=self.namespace)
 
-    def describe(self) -> Dict:
+    def describe(self) -> IndexConfig:
         if self.index is not None:
             stats = self.index.describe_index_stats()
-            return {
-                "type": self.type,
-                "dimensions": stats["dimension"],
-                "vectors": stats["namespaces"][self.namespace]["vector_count"],
-            }
+            return IndexConfig(
+                type=self.type,
+                dimensions=stats["dimension"],
+                vectors=stats["namespaces"][self.namespace]["vector_count"],
+            )
         else:
-            raise ValueError("Index is None, cannot describe index stats.")
+            return IndexConfig(
+                type=self.type,
+                dimensions=self.dimensions or 0,
+                vectors=0,
+            )
 
     def query(
         self,
diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py
index 76d60d2b..54054c84 100644
--- a/semantic_router/index/postgres.py
+++ b/semantic_router/index/postgres.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 import numpy as np
 from pydantic import BaseModel, Field
 
-from semantic_router.index.base import BaseIndex
+from semantic_router.index.base import BaseIndex, IndexConfig
 from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding
 from semantic_router.utils.logger import logger
 
@@ -324,17 +324,21 @@ class PostgresIndex(BaseIndex):
             cur.execute(f"DELETE FROM {table_name} WHERE route = '{route_name}'")
             self.conn.commit()
 
-    def describe(self) -> Dict:
+    def describe(self) -> IndexConfig:
         """
         Describes the index by returning its type, dimensions, and total vector count.
 
-        :return: A dictionary containing the index's type, dimensions, and total vector count.
-        :rtype: Dict
-        :raises TypeError: If the database connection is not established.
+        :return: An IndexConfig object containing the index's type, dimensions, and total vector count.
+        :rtype: IndexConfig
         """
         table_name = self._get_table_name()
         if not isinstance(self.conn, psycopg2.extensions.connection):
-            raise TypeError("Index has not established a connection to Postgres")
+            logger.warning("Index has not established a connection to Postgres")
+            return IndexConfig(
+                type=self.type,
+                dimensions=self.dimensions or 0,
+                vectors=0,
+            )
         with self.conn.cursor() as cur:
             cur.execute(f"SELECT COUNT(*) FROM {table_name}")
             count = cur.fetchone()
@@ -342,11 +346,11 @@ class PostgresIndex(BaseIndex):
                 count = 0
             else:
                 count = count[0]  # Extract the actual count from the tuple
-            return {
-                "type": self.type,
-                "dimensions": self.dimensions,
-                "total_vector_count": count,
-            }
+            return IndexConfig(
+                type=self.type,
+                dimensions=self.dimensions or 0,
+                vectors=count,
+            )
 
     def query(
         self,
diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index 51846629..5986f2c0 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 from pydantic import Field
 
-from semantic_router.index.base import BaseIndex
+from semantic_router.index.base import BaseIndex, IndexConfig
 from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding, Utterance
 from semantic_router.utils.logger import logger
 
@@ -246,14 +246,14 @@ class QdrantIndex(BaseIndex):
             ),
         )
 
-    def describe(self) -> Dict:
+    def describe(self) -> IndexConfig:
         collection_info = self.client.get_collection(self.index_name)
 
-        return {
-            "type": self.type,
-            "dimensions": collection_info.config.params.vectors.size,
-            "vectors": collection_info.points_count,
-        }
+        return IndexConfig(
+            type=self.type,
+            dimensions=collection_info.config.params.vectors.size,
+            vectors=collection_info.points_count,
+        )
 
     def query(
         self,
diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index 5bff5f5d..c551b124 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -15,6 +15,7 @@ from semantic_router.encoders import AutoEncoder, DenseEncoder, OpenAIEncoder
 from semantic_router.index.base import BaseIndex
 from semantic_router.index.local import LocalIndex
 from semantic_router.index.pinecone import PineconeIndex
+from semantic_router.index.qdrant import QdrantIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
 from semantic_router.schema import (
@@ -421,7 +422,8 @@ class BaseRouter(BaseModel):
         simulate_static: bool = False,
         route_filter: Optional[List[str]] = None,
     ) -> RouteChoice:
-        if self.index.index is None or self.routes is None:
+        ready = self._index_ready()
+        if not ready:
             raise ValueError("Index or routes are not populated.")
         # if no vector provided, encode text to get vector
         if vector is None:
@@ -479,7 +481,8 @@ class BaseRouter(BaseModel):
         simulate_static: bool = False,
         route_filter: Optional[List[str]] = None,
     ) -> RouteChoice:
-        if self.index.index is None or self.routes is None:
+        ready = self._index_ready()  # TODO: need async version for qdrant
+        if not ready:
             raise ValueError("Index or routes are not populated.")
         # if no vector provided, encode text to get vector
         if vector is None:
@@ -527,6 +530,20 @@ class BaseRouter(BaseModel):
             # if no route passes threshold, return empty route choice
             return RouteChoice()
 
+    def _index_ready(self) -> bool:
+        """Method to check if the index is ready to be used.
+
+        :return: True if the index is ready, False otherwise.
+        :rtype: bool
+        """
+        if self.index.index is None or self.routes is None:
+            return False
+        if isinstance(self.index, QdrantIndex):
+            info = self.index.describe()
+            if info.vectors == 0:
+                return False
+        return True
+
     def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
         """Runs a sync of the local routes with the remote index.
 
-- 
GitLab