From b8d44aaf1d48177019b4ae9795599ba627a38534 Mon Sep 17 00:00:00 2001 From: jamescalam <james.briggs@hotmail.com> Date: Thu, 28 Nov 2024 00:40:18 +0100 Subject: [PATCH] chore: lint --- semantic_router/encoders/aurelio.py | 4 +++- semantic_router/encoders/base.py | 4 +--- semantic_router/encoders/bm25.py | 4 +++- semantic_router/encoders/tfidf.py | 4 +++- semantic_router/index/pinecone.py | 12 +++++++++--- semantic_router/routers/base.py | 15 ++++++++++++--- 6 files changed, 31 insertions(+), 12 deletions(-) diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index 8824b2f1..779fe6b1 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -19,9 +19,11 @@ class AurelioSparseEncoder(SparseEncoder): def __init__( self, - name: str = "bm25", + name: str | None = None, api_key: Optional[str] = None, ): + if name is None: + name = "bm25" super().__init__(name=name) if api_key is None: api_key = os.getenv("AURELIO_API_KEY") diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index ed0eb523..0a25f210 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -35,9 +35,7 @@ class SparseEncoder(BaseModel): def __call__(self, docs: List[str]) -> List[SparseEmbedding]: raise NotImplementedError("Subclasses must implement this method") - async def acall( - self, docs: List[str] - ) -> Coroutine[Any, Any, List[SparseEmbedding]]: + async def acall(self, docs: List[str]) -> list[SparseEmbedding]: raise NotImplementedError("Subclasses must implement this method") def _array_to_sparse_embeddings( diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index e2bb24c1..3357ded8 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -12,9 +12,11 @@ class BM25Encoder(TfidfEncoder): def __init__( self, - name: str = "bm25", + name: str | None = None, use_default_params: bool = True, ): + if name is None: + name = "bm25" super().__init__(name=name) try: from pinecone_text.sparse import BM25Encoder as encoder diff --git a/semantic_router/encoders/tfidf.py b/semantic_router/encoders/tfidf.py index d9d97a47..a7ac9136 100644 --- a/semantic_router/encoders/tfidf.py +++ b/semantic_router/encoders/tfidf.py @@ -15,7 +15,9 @@ class TfidfEncoder(SparseEncoder): idf: ndarray = np.array([]) word_index: Dict = {} - def __init__(self, name: str = "tfidf"): + def __init__(self, name: str | None = None): + if name is None: + name = "tfidf" super().__init__(name=name) self.word_index = {} self.idf = np.array([]) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index bb6ed3ef..6086de1c 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -223,7 +223,7 @@ class PineconeIndex(BaseIndex): # if the index doesn't exist and we don't have the dimensions # we raise warning logger.warning("Index could not be initialized.") - self.host = index_stats["host"] if index_stats else None + self.host = index_stats["host"] if index_stats else "" def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records.""" @@ -466,7 +466,7 @@ class PineconeIndex(BaseIndex): :rtype: Tuple[np.ndarray, List[str]] :raises ValueError: If the index is not populated. """ - if self.async_client is None or self.host is None: + if self.async_client is None or self.host == "": raise ValueError("Async client or host are not initialized.") query_vector_list = vector.tolist() if route_filter is not None: @@ -492,7 +492,7 @@ class PineconeIndex(BaseIndex): :return: A list of (route_name, utterance) objects. :rtype: List[Tuple] """ - if self.async_client is None or self.host is None: + if self.async_client is None or self.host == "": raise ValueError("Async client or host are not initialized.") return await self._async_get_routes() @@ -519,6 +519,8 @@ class PineconeIndex(BaseIndex): "top_k": top_k, "include_metadata": include_metadata, } + if self.host == "": + raise ValueError("self.host is not initialized.") async with self.async_client.post( f"https://{self.host}/query", json=params, @@ -569,6 +571,8 @@ class PineconeIndex(BaseIndex): """ if self.index is None: raise ValueError("Index is None, could not retrieve vector IDs.") + if self.host == "": + raise ValueError("self.host is not initialized.") all_vector_ids = [] next_page_token = None @@ -623,6 +627,8 @@ class PineconeIndex(BaseIndex): :return: A dictionary containing the metadata for the vector. :rtype: dict """ + if self.host == "": + raise ValueError("self.host is not initialized.") url = f"https://{self.host}/vectors/fetch" params = { diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 21ff9a32..0172e5d8 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -700,18 +700,27 @@ class BaseRouter(BaseModel): def from_json(cls, file_path: str): config = RouterConfig.from_file(file_path) encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model - return cls(encoder=encoder, routes=config.routes) + if isinstance(encoder, DenseEncoder): + return cls(encoder=encoder, routes=config.routes) + else: + raise ValueError(f"{type(encoder)} not supported for loading from JSON.") @classmethod def from_yaml(cls, file_path: str): config = RouterConfig.from_file(file_path) encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model - return cls(encoder=encoder, routes=config.routes) + if isinstance(encoder, DenseEncoder): + return cls(encoder=encoder, routes=config.routes) + else: + raise ValueError(f"{type(encoder)} not supported for loading from YAML.") @classmethod def from_config(cls, config: RouterConfig, index: Optional[BaseIndex] = None): encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model - return cls(encoder=encoder, routes=config.routes, index=index) + if isinstance(encoder, DenseEncoder): + return cls(encoder=encoder, routes=config.routes, index=index) + else: + raise ValueError(f"{type(encoder)} not supported for loading from config.") def add(self, route: Route): """Add a route to the local SemanticRouter and index. -- GitLab