Skip to content
Snippets Groups Projects
Unverified Commit 46789a42 authored by James Briggs
Browse files

lint

parent ef7eaa6b
No related branches found
No related tags found
No related merge requests found
import json import json
import os import os
import random
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import yaml import yaml
from tqdm.auto import tqdm
from semantic_router.encoders import BaseEncoder, OpenAIEncoder from semantic_router.encoders import BaseEncoder, OpenAIEncoder
from semantic_router.linear import similarity_matrix, top_scores from semantic_router.linear import similarity_matrix, top_scores
...@@ -11,8 +13,6 @@ from semantic_router.llms import BaseLLM, OpenAILLM ...@@ -11,8 +13,6 @@ from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route from semantic_router.route import Route
from semantic_router.schema import Encoder, EncoderType, RouteChoice from semantic_router.schema import Encoder, EncoderType, RouteChoice
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
from tqdm.auto import tqdm
import random
def is_valid(layer_config: str) -> bool: def is_valid(layer_config: str) -> bool:
...@@ -311,16 +311,14 @@ class RouteLayer: ...@@ -311,16 +311,14 @@ class RouteLayer:
) )
def _encode(self, text: str) -> np.ndarray: def _encode(self, text: str) -> np.ndarray:
"""Given some text, encode it. """Given some text, encode it."""
"""
# create query vector # create query vector
xq = np.array(self.encoder([text])) xq = np.array(self.encoder([text]))
xq = np.squeeze(xq) # Reduce to 1d array. xq = np.squeeze(xq) # Reduce to 1d array.
return xq return xq
def _retrieve(self, xq: np.ndarray, top_k: int = 5) -> List[dict]: def _retrieve(self, xq: np.ndarray, top_k: int = 5) -> List[dict]:
"""Given a query vector, retrieve the top_k most similar records. """Given a query vector, retrieve the top_k most similar records."""
"""
if self.index is not None: if self.index is not None:
# calculate similarity matrix # calculate similarity matrix
sim = similarity_matrix(xq, self.index) sim = similarity_matrix(xq, self.index)
...@@ -358,11 +356,8 @@ class RouteLayer: ...@@ -358,11 +356,8 @@ class RouteLayer:
return max(scores) > threshold return max(scores) > threshold
else: else:
return False return False
def _update_thresholds( def _update_thresholds(self, score_thresholds: Optional[Dict[str, float]] = None):
self,
score_thresholds: Optional[Dict[str, float]] = None
):
""" """
Update the score thresholds for each route. Update the score thresholds for each route.
""" """
...@@ -420,7 +415,7 @@ class RouteLayer: ...@@ -420,7 +415,7 @@ class RouteLayer:
best_thresholds = thresholds best_thresholds = thresholds
# update route layer to best thresholds # update route layer to best thresholds
self._update_thresholds(score_thresholds=best_thresholds) self._update_thresholds(score_thresholds=best_thresholds)
def evaluate(self, X: List[Union[str, float]], y: List[str]) -> float: def evaluate(self, X: List[Union[str, float]], y: List[str]) -> float:
""" """
Evaluate the accuracy of the route selection. Evaluate the accuracy of the route selection.
...@@ -440,8 +435,7 @@ def threshold_random_search( ...@@ -440,8 +435,7 @@ def threshold_random_search(
route_layer: RouteLayer, route_layer: RouteLayer,
search_range: Union[int, float], search_range: Union[int, float],
) -> Tuple[float, Dict[str, float]]: ) -> Tuple[float, Dict[str, float]]:
"""Performs a random search iteration given a route layer and a search range. """Performs a random search iteration given a route layer and a search range."""
"""
# extract the route names # extract the route names
routes = route_layer.get_route_thresholds() routes = route_layer.get_route_thresholds()
route_names = list(routes.keys()) route_names = list(routes.keys())
...@@ -451,13 +445,14 @@ def threshold_random_search( ...@@ -451,13 +445,14 @@ def threshold_random_search(
for threshold in route_thresholds: for threshold in route_thresholds:
score_threshold_values.append( score_threshold_values.append(
np.linspace( np.linspace(
start=max(threshold-search_range, 0.0), start=max(threshold - search_range, 0.0),
stop=min(threshold+search_range, 1.0), stop=min(threshold + search_range, 1.0),
num=100 num=100,
) )
) )
# Generate a random threshold for each route # Generate a random threshold for each route
score_thresholds = { score_thresholds = {
route: random.choice(score_threshold_values[i]) for i, route in enumerate(route_names) route: random.choice(score_threshold_values[i])
for i, route in enumerate(route_names)
} }
return score_thresholds return score_thresholds
...@@ -10,7 +10,7 @@ from semantic_router.encoders import ( ...@@ -10,7 +10,7 @@ from semantic_router.encoders import (
FastEmbedEncoder, FastEmbedEncoder,
OpenAIEncoder, OpenAIEncoder,
) )
from semantic_router.utils.splitters import semantic_splitter, DocumentSplit from semantic_router.utils.splitters import DocumentSplit, semantic_splitter
class EncoderType(Enum): class EncoderType(Enum):
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment