Skip to content
Snippets Groups Projects
Unverified Commit eb56039a authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Removed _simulate_static_route_selection

Simulation of static routes from dynamic routes now handled in __call__ as this avoids a situation where we have subtly different logic in __call__ compared to _simulate_static_route_selection.
parent 5eca3c02
No related branches found
No related tags found
No related merge requests found
......@@ -220,6 +220,7 @@ class RouteLayer:
self,
text: Optional[str] = None,
vector: Optional[List[float]] = None,
simulate_static: bool = False,
) -> RouteChoice:
# if no vector provided, encode text to get vector
if vector is None:
......@@ -230,7 +231,7 @@ class RouteLayer:
route, top_class_scores = self._retrieve_top_route(vector)
passed = self._check_threshold(top_class_scores, route)
if passed and route is not None:
if passed and route is not None and not simulate_static:
if route.function_schema and text is None:
raise ValueError(
"Route has a function schema, but no text was provided."
......@@ -248,6 +249,13 @@ class RouteLayer:
else:
route.llm = self.llm
return route(text)
elif passed and route is not None and simulate_static:
return RouteChoice(
name=route.name,
function_call=None,
similarity_score=None,
trigger=None,
)
else:
# if no route passes threshold, return empty route choice
return RouteChoice()
......@@ -497,28 +505,13 @@ class RouteLayer:
"""
correct = 0
for xq, target_route in zip(Xq, y):
# We can't do route_choice = self(vector=xq) here as it won't work for dynamic routes.
route_choice = self._simulate_static_route_selection(vector=xq)
# We treate dynamic routes as static here, because when evaluating we use only vectors, and dynamic routes expect strings by default.
route_choice = self(vector=xq, simulate_static=True)
if route_choice.name == target_route:
correct += 1
accuracy = correct / len(Xq)
return accuracy
def _simulate_static_route_selection(self, vector: List[float]) -> RouteChoice:
"""
Simulate the route selection process treating all routes as static, including threshold checking.
Dynamic routes require a query string to be passed to the __call__ method, but here we work with vectors to boost performance.
Hence, we simulate the route selection process treating all routes as static.
"""
route, scores = self._retrieve_top_route(vector)
passed = self._check_threshold(scores, route)
if passed and route is not None:
return RouteChoice(
name=route.name, function_call=None, similarity_score=None, trigger=None
)
else:
return RouteChoice()
def _get_route_names(self) -> List[str]:
return [route.name for route in self.routes]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment