Skip to content
Snippets Groups Projects
Commit 9fc31edc authored by the-anup-das's avatar the-anup-das
Browse files

lint fix

parent 2596bf9f
No related branches found
No related tags found
No related merge requests found
......@@ -150,7 +150,7 @@ class HFEndpointEncoder(BaseEncoder):
score_threshold (float): A threshold value used for filtering or processing the embeddings.
"""
name: Optional[str] = "hugging_face_custom_endpoint"
name: str = "hugging_face_custom_endpoint"
huggingface_url: Optional[str] = None
huggingface_api_key: Optional[str] = None
score_threshold: float = 0.8
......@@ -181,17 +181,15 @@ class HFEndpointEncoder(BaseEncoder):
huggingface_url = huggingface_url or os.getenv("HF_API_URL")
huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY")
super().__init__(name=name, score_threshold=score_threshold) # type: ignore
if huggingface_url is None:
raise ValueError("HuggingFace endpoint url cannot be 'None'.")
if huggingface_api_key is None:
raise ValueError("HuggingFace API key cannot be 'None'.")
super().__init__(
name=name,
huggingface_url=huggingface_url,
huggingface_api_key=huggingface_api_key,
score_threshold=score_threshold,
)
self.huggingface_url = huggingface_url or os.getenv("HF_API_URL")
self.huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY")
try:
self.query({"inputs": "Hello World!", "parameters": {}})
......
......@@ -35,7 +35,7 @@ class TfidfEncoder(BaseEncoder):
docs = []
for route in routes:
for doc in route.utterances:
docs.append(self._preprocess(doc))
docs.append(self._preprocess(doc)) # type: ignore
self.word_index = self._build_word_index(docs)
self.idf = self._compute_idf(docs)
......
......@@ -5,7 +5,7 @@ import random
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import yaml
import yaml # type: ignore
from tqdm.auto import tqdm
from semantic_router.encoders import BaseEncoder, OpenAIEncoder
......@@ -328,7 +328,7 @@ class RouteLayer:
def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
# create embeddings
embeds = self.encoder(route.utterances)
embeds = self.encoder(route.utterances) # type: ignore
# if route has no score_threshold, use default
if route.score_threshold is None:
route.score_threshold = self.score_threshold
......@@ -337,7 +337,7 @@ class RouteLayer:
self.index.add(
embeddings=embeds,
routes=[route.name] * len(route.utterances),
utterances=route.utterances,
utterances=route.utterances, # type: ignore
)
self.routes.append(route)
......@@ -383,14 +383,14 @@ class RouteLayer:
all_utterances = [
utterance for route in routes for utterance in route.utterances
]
embedded_utterances = self.encoder(all_utterances)
embedded_utterances = self.encoder(all_utterances) # type: ignore
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index.add(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
utterances=all_utterances, # type: ignore
)
def _encode(self, text: str) -> Any:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment