Skip to content
Snippets Groups Projects
Commit 9fc31edc authored by the-anup-das's avatar the-anup-das
Browse files

lint fix

parent 2596bf9f
No related branches found
No related tags found
No related merge requests found
...@@ -150,7 +150,7 @@ class HFEndpointEncoder(BaseEncoder): ...@@ -150,7 +150,7 @@ class HFEndpointEncoder(BaseEncoder):
score_threshold (float): A threshold value used for filtering or processing the embeddings. score_threshold (float): A threshold value used for filtering or processing the embeddings.
""" """
name: Optional[str] = "hugging_face_custom_endpoint" name: str = "hugging_face_custom_endpoint"
huggingface_url: Optional[str] = None huggingface_url: Optional[str] = None
huggingface_api_key: Optional[str] = None huggingface_api_key: Optional[str] = None
score_threshold: float = 0.8 score_threshold: float = 0.8
...@@ -181,17 +181,15 @@ class HFEndpointEncoder(BaseEncoder): ...@@ -181,17 +181,15 @@ class HFEndpointEncoder(BaseEncoder):
huggingface_url = huggingface_url or os.getenv("HF_API_URL") huggingface_url = huggingface_url or os.getenv("HF_API_URL")
huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY") huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY")
super().__init__(name=name, score_threshold=score_threshold) # type: ignore
if huggingface_url is None: if huggingface_url is None:
raise ValueError("HuggingFace endpoint url cannot be 'None'.") raise ValueError("HuggingFace endpoint url cannot be 'None'.")
if huggingface_api_key is None: if huggingface_api_key is None:
raise ValueError("HuggingFace API key cannot be 'None'.") raise ValueError("HuggingFace API key cannot be 'None'.")
super().__init__( self.huggingface_url = huggingface_url or os.getenv("HF_API_URL")
name=name, self.huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY")
huggingface_url=huggingface_url,
huggingface_api_key=huggingface_api_key,
score_threshold=score_threshold,
)
try: try:
self.query({"inputs": "Hello World!", "parameters": {}}) self.query({"inputs": "Hello World!", "parameters": {}})
......
...@@ -35,7 +35,7 @@ class TfidfEncoder(BaseEncoder): ...@@ -35,7 +35,7 @@ class TfidfEncoder(BaseEncoder):
docs = [] docs = []
for route in routes: for route in routes:
for doc in route.utterances: for doc in route.utterances:
docs.append(self._preprocess(doc)) docs.append(self._preprocess(doc)) # type: ignore
self.word_index = self._build_word_index(docs) self.word_index = self._build_word_index(docs)
self.idf = self._compute_idf(docs) self.idf = self._compute_idf(docs)
......
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import yaml import yaml # type: ignore
from tqdm.auto import tqdm from tqdm.auto import tqdm
from semantic_router.encoders import BaseEncoder, OpenAIEncoder from semantic_router.encoders import BaseEncoder, OpenAIEncoder
...@@ -328,7 +328,7 @@ class RouteLayer: ...@@ -328,7 +328,7 @@ class RouteLayer:
def add(self, route: Route): def add(self, route: Route):
logger.info(f"Adding `{route.name}` route") logger.info(f"Adding `{route.name}` route")
# create embeddings # create embeddings
embeds = self.encoder(route.utterances) embeds = self.encoder(route.utterances) # type: ignore
# if route has no score_threshold, use default # if route has no score_threshold, use default
if route.score_threshold is None: if route.score_threshold is None:
route.score_threshold = self.score_threshold route.score_threshold = self.score_threshold
...@@ -337,7 +337,7 @@ class RouteLayer: ...@@ -337,7 +337,7 @@ class RouteLayer:
self.index.add( self.index.add(
embeddings=embeds, embeddings=embeds,
routes=[route.name] * len(route.utterances), routes=[route.name] * len(route.utterances),
utterances=route.utterances, utterances=route.utterances, # type: ignore
) )
self.routes.append(route) self.routes.append(route)
...@@ -383,14 +383,14 @@ class RouteLayer: ...@@ -383,14 +383,14 @@ class RouteLayer:
all_utterances = [ all_utterances = [
utterance for route in routes for utterance in route.utterances utterance for route in routes for utterance in route.utterances
] ]
embedded_utterances = self.encoder(all_utterances) embedded_utterances = self.encoder(all_utterances) # type: ignore
# create route array # create route array
route_names = [route.name for route in routes for _ in route.utterances] route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index # add everything to the index
self.index.add( self.index.add(
embeddings=embedded_utterances, embeddings=embedded_utterances,
routes=route_names, routes=route_names,
utterances=all_utterances, utterances=all_utterances, # type: ignore
) )
def _encode(self, text: str) -> Any: def _encode(self, text: str) -> Any:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment