Skip to content
Snippets Groups Projects
Commit 90fe4d11 authored by James Briggs's avatar James Briggs
Browse files

feat: modify index readiness checks

parent e12f8ebc
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,12 @@ from semantic_router.utils.logger import logger ...@@ -15,6 +15,12 @@ from semantic_router.utils.logger import logger
RETRY_WAIT_TIME = 2.5 RETRY_WAIT_TIME = 2.5
class IndexConfig(BaseModel):
    """Summary of an index as returned by the ``describe()`` methods.

    Groups the three values every index implementation reports: the index
    type, the dimensionality of its vectors and the number of vectors held.
    """

    # Identifier of the index implementation (filled from ``self.type``).
    type: str
    # Dimensionality of the stored vectors; implementations pass 0 when the
    # value is not known yet.
    dimensions: int
    # Number of vectors currently stored; 0 for an empty or unconnected index.
    vectors: int
class BaseIndex(BaseModel): class BaseIndex(BaseModel):
""" """
Base class for indices using Pydantic's BaseModel. Base class for indices using Pydantic's BaseModel.
...@@ -146,10 +152,10 @@ class BaseIndex(BaseModel): ...@@ -146,10 +152,10 @@ class BaseIndex(BaseModel):
""" """
raise NotImplementedError("This method should be implemented by subclasses.") raise NotImplementedError("This method should be implemented by subclasses.")
def describe(self) -> IndexConfig:
    """Return an IndexConfig object with index details.

    The details are the index type, dimensions, and total vector count.
    This method should be implemented by subclasses.

    :return: The index type, dimensions and total vector count.
    :rtype: IndexConfig
    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("This method should be implemented by subclasses.")
......
...@@ -67,13 +67,6 @@ class HybridLocalIndex(LocalIndex): ...@@ -67,13 +67,6 @@ class HybridLocalIndex(LocalIndex):
return [] return []
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)] return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
def describe(self) -> Dict:
    """Describe the index as a plain dictionary.

    :return: A dictionary with the keys ``type``, ``dimensions`` and
        ``vectors``. Both numbers are 0 while ``self.index`` is ``None``.
    :rtype: Dict
    """
    if self.index is None:
        # No array has been built yet, so there is no shape to read.
        vectors, dimensions = 0, 0
    else:
        # Rows are counted as vectors, columns as dimensions.
        vectors, dimensions = self.index.shape[0], self.index.shape[1]
    return {
        "type": self.type,
        "dimensions": dimensions,
        "vectors": vectors,
    }
def _sparse_dot_product( def _sparse_dot_product(
self, vec_a: dict[int, float], vec_b: dict[int, float] self, vec_a: dict[int, float], vec_b: dict[int, float]
) -> float: ) -> float:
......
...@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Dict ...@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Dict
import numpy as np import numpy as np
from semantic_router.schema import ConfigParameter, SparseEmbedding, Utterance from semantic_router.schema import ConfigParameter, SparseEmbedding, Utterance
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex, IndexConfig
from semantic_router.linear import similarity_matrix, top_scores from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
from typing import Any from typing import Any
...@@ -75,12 +75,12 @@ class LocalIndex(BaseIndex): ...@@ -75,12 +75,12 @@ class LocalIndex(BaseIndex):
return [] return []
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)] return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
def describe(self) -> IndexConfig:
    """Describe the local index.

    :return: An IndexConfig holding the index type plus the dimensions and
        vector count read from the shape of ``self.index``; both numbers
        are 0 while ``self.index`` is ``None``.
    :rtype: IndexConfig
    """
    if self.index is None:
        # No array has been built yet, so there is no shape to read.
        vectors, dimensions = 0, 0
    else:
        # Rows are counted as vectors, columns as dimensions.
        vectors, dimensions = self.index.shape[0], self.index.shape[1]
    return IndexConfig(
        type=self.type,
        dimensions=dimensions,
        vectors=vectors,
    )
def query( def query(
self, self,
......
...@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Union, Tuple ...@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Union, Tuple
import numpy as np import numpy as np
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex, IndexConfig
from semantic_router.schema import ConfigParameter, SparseEmbedding from semantic_router.schema import ConfigParameter, SparseEmbedding
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -449,16 +449,20 @@ class PineconeIndex(BaseIndex): ...@@ -449,16 +449,20 @@ class PineconeIndex(BaseIndex):
def delete_all(self):
    """Delete everything stored under this index's namespace.

    Forwards to ``self.index.delete`` with ``delete_all=True`` and
    ``namespace=self.namespace``.
    """
    self.index.delete(delete_all=True, namespace=self.namespace)
def describe(self) -> IndexConfig:
    """Describe the Pinecone index.

    When no index object is attached, a placeholder description with zero
    vectors is returned instead of raising, so callers can use the result
    as a readiness signal.

    :return: An IndexConfig with the index type, dimensions and the vector
        count of ``self.namespace``.
    :rtype: IndexConfig
    """
    if self.index is None:
        # Not connected yet: report what is known locally and zero vectors.
        return IndexConfig(
            type=self.type,
            dimensions=self.dimensions or 0,
            vectors=0,
        )
    stats = self.index.describe_index_stats()
    try:
        vector_count = stats["namespaces"][self.namespace]["vector_count"]
    except KeyError:
        # The namespace is missing from the stats, so treat it as empty
        # rather than failing the whole call.
        # NOTE(review): Pinecone appears to omit namespaces that hold no
        # vectors - confirm against the client version in use.
        vector_count = 0
    return IndexConfig(
        type=self.type,
        dimensions=stats["dimension"],
        vectors=vector_count,
    )
def query( def query(
self, self,
......
...@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union ...@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import numpy as np import numpy as np
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex, IndexConfig
from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -324,17 +324,21 @@ class PostgresIndex(BaseIndex): ...@@ -324,17 +324,21 @@ class PostgresIndex(BaseIndex):
cur.execute(f"DELETE FROM {table_name} WHERE route = '{route_name}'") cur.execute(f"DELETE FROM {table_name} WHERE route = '{route_name}'")
self.conn.commit() self.conn.commit()
def describe(self) -> Dict: def describe(self) -> IndexConfig:
""" """
Describes the index by returning its type, dimensions, and total vector count. Describes the index by returning its type, dimensions, and total vector count.
:return: A dictionary containing the index's type, dimensions, and total vector count. :return: An IndexConfig object containing the index's type, dimensions, and total vector count.
:rtype: Dict :rtype: IndexConfig
:raises TypeError: If the database connection is not established.
""" """
table_name = self._get_table_name() table_name = self._get_table_name()
if not isinstance(self.conn, psycopg2.extensions.connection): if not isinstance(self.conn, psycopg2.extensions.connection):
raise TypeError("Index has not established a connection to Postgres") logger.warning("Index has not established a connection to Postgres")
return IndexConfig(
type=self.type,
dimensions=self.dimensions or 0,
vectors=0,
)
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute(f"SELECT COUNT(*) FROM {table_name}") cur.execute(f"SELECT COUNT(*) FROM {table_name}")
count = cur.fetchone() count = cur.fetchone()
...@@ -342,11 +346,11 @@ class PostgresIndex(BaseIndex): ...@@ -342,11 +346,11 @@ class PostgresIndex(BaseIndex):
count = 0 count = 0
else: else:
count = count[0] # Extract the actual count from the tuple count = count[0] # Extract the actual count from the tuple
return { return IndexConfig(
"type": self.type, type=self.type,
"dimensions": self.dimensions, dimensions=self.dimensions or 0,
"total_vector_count": count, vectors=count,
} )
def query( def query(
self, self,
......
...@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex, IndexConfig
from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding, Utterance from semantic_router.schema import ConfigParameter, Metric, SparseEmbedding, Utterance
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -246,14 +246,14 @@ class QdrantIndex(BaseIndex): ...@@ -246,14 +246,14 @@ class QdrantIndex(BaseIndex):
), ),
) )
def describe(self) -> IndexConfig:
    """Describe the Qdrant collection backing this index.

    Fetches the collection info for ``self.index_name`` from the client on
    every call.

    :return: An IndexConfig with the index type, the configured vector size
        and the collection's point count.
    :rtype: IndexConfig
    """
    collection_info = self.client.get_collection(self.index_name)
    return IndexConfig(
        type=self.type,
        dimensions=collection_info.config.params.vectors.size,
        # NOTE(review): points_count looks optional in the qdrant-client
        # model; fall back to 0 so the ``int`` field always validates.
        vectors=collection_info.points_count or 0,
    )
def query( def query(
self, self,
......
...@@ -15,6 +15,7 @@ from semantic_router.encoders import AutoEncoder, DenseEncoder, OpenAIEncoder ...@@ -15,6 +15,7 @@ from semantic_router.encoders import AutoEncoder, DenseEncoder, OpenAIEncoder
from semantic_router.index.base import BaseIndex from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex from semantic_router.index.pinecone import PineconeIndex
from semantic_router.index.qdrant import QdrantIndex
from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import ( from semantic_router.schema import (
...@@ -421,7 +422,8 @@ class BaseRouter(BaseModel): ...@@ -421,7 +422,8 @@ class BaseRouter(BaseModel):
simulate_static: bool = False, simulate_static: bool = False,
route_filter: Optional[List[str]] = None, route_filter: Optional[List[str]] = None,
) -> RouteChoice: ) -> RouteChoice:
if self.index.index is None or self.routes is None: ready = self._index_ready()
if not ready:
raise ValueError("Index or routes are not populated.") raise ValueError("Index or routes are not populated.")
# if no vector provided, encode text to get vector # if no vector provided, encode text to get vector
if vector is None: if vector is None:
...@@ -479,7 +481,8 @@ class BaseRouter(BaseModel): ...@@ -479,7 +481,8 @@ class BaseRouter(BaseModel):
simulate_static: bool = False, simulate_static: bool = False,
route_filter: Optional[List[str]] = None, route_filter: Optional[List[str]] = None,
) -> RouteChoice: ) -> RouteChoice:
if self.index.index is None or self.routes is None: ready = self._index_ready() # TODO: need async version for qdrant
if not ready:
raise ValueError("Index or routes are not populated.") raise ValueError("Index or routes are not populated.")
# if no vector provided, encode text to get vector # if no vector provided, encode text to get vector
if vector is None: if vector is None:
...@@ -527,6 +530,20 @@ class BaseRouter(BaseModel): ...@@ -527,6 +530,20 @@ class BaseRouter(BaseModel):
# if no route passes threshold, return empty route choice # if no route passes threshold, return empty route choice
return RouteChoice() return RouteChoice()
def _index_ready(self) -> bool:
    """Method to check if the index is ready to be used.

    The index counts as ready when both ``self.index.index`` and
    ``self.routes`` are set. For a ``QdrantIndex`` the description returned
    by ``describe()`` must additionally report at least one vector.

    :return: True if the index is ready, False otherwise.
    :rtype: bool
    """
    if self.index.index is None or self.routes is None:
        return False
    if isinstance(self.index, QdrantIndex):
        # describe() is only consulted for Qdrant; an empty collection is
        # treated as not ready.
        return self.index.describe().vectors != 0
    return True
def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
"""Runs a sync of the local routes with the remote index. """Runs a sync of the local routes with the remote index.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment