Skip to content
Snippets Groups Projects
Unverified Commit 43e662df authored by Zahid Syed's avatar Zahid Syed Committed by GitHub
Browse files

Merge branch 'main' into zahid/issue_fixes

parents 7ac33fa5 ea2260ee
No related branches found
No related tags found
No related merge requests found
...@@ -220,29 +220,18 @@ class RouteLayer: ...@@ -220,29 +220,18 @@ class RouteLayer:
self, self,
text: Optional[str] = None, text: Optional[str] = None,
vector: Optional[List[float]] = None, vector: Optional[List[float]] = None,
simulate_static: bool = False,
) -> RouteChoice: ) -> RouteChoice:
# if no vector provided, encode text to get vector # if no vector provided, encode text to get vector
if vector is None: if vector is None:
if text is None: if text is None:
raise ValueError("Either text or vector must be provided") raise ValueError("Either text or vector must be provided")
vector_arr = self._encode(text=text) vector = self._encode(text=text)
else:
vector_arr = np.array(vector) route, top_class_scores = self._retrieve_top_route(vector)
# get relevant results (scores and routes) passed = self._check_threshold(top_class_scores, route)
results = self._retrieve(xq=vector_arr)
# decide most relevant routes if passed and route is not None and not simulate_static:
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
if route is None:
return RouteChoice()
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
passed = self._pass_threshold(top_class_scores, threshold)
if passed:
if route.function_schema and text is None: if route.function_schema and text is None:
raise ValueError( raise ValueError(
"Route has a function schema, but no text was provided." "Route has a function schema, but no text was provided."
...@@ -260,10 +249,45 @@ class RouteLayer: ...@@ -260,10 +249,45 @@ class RouteLayer:
else: else:
route.llm = self.llm route.llm = self.llm
return route(text) return route(text)
elif passed and route is not None and simulate_static:
return RouteChoice(
name=route.name,
function_call=None,
similarity_score=None,
trigger=None,
)
else: else:
# if no route passes threshold, return empty route choice # if no route passes threshold, return empty route choice
return RouteChoice() return RouteChoice()
def _retrieve_top_route(
self, vector: List[float]
) -> Tuple[Optional[Route], List[float]]:
"""
Retrieve the top matching route based on the given vector.
Returns a tuple of the route (if any) and the scores of the top class.
"""
# get relevant results (scores and routes)
results = self._retrieve(xq=np.array(vector))
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
return route, top_class_scores
def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
"""
Check if the route's score passes the specified threshold.
"""
if route is None:
return False
threshold = (
route.score_threshold
if route.score_threshold is not None
else self.score_threshold
)
return self._pass_threshold(scores, threshold)
def __str__(self): def __str__(self):
return ( return (
f"RouteLayer(encoder={self.encoder}, " f"RouteLayer(encoder={self.encoder}, "
...@@ -481,7 +505,8 @@ class RouteLayer: ...@@ -481,7 +505,8 @@ class RouteLayer:
""" """
correct = 0 correct = 0
for xq, target_route in zip(Xq, y): for xq, target_route in zip(Xq, y):
route_choice = self(vector=xq) # We treate dynamic routes as static here, because when evaluating we use only vectors, and dynamic routes expect strings by default.
route_choice = self(vector=xq, simulate_static=True)
if route_choice.name == target_route: if route_choice.name == target_route:
correct += 1 correct += 1
accuracy = correct / len(Xq) accuracy = correct / len(Xq)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment