Skip to content
Snippets Groups Projects
Commit 2247b9b0 authored by zahid-syed's avatar zahid-syed
Browse files

Add batches to layer.py for semantic router

parent a3732f1f
No related branches found
No related tags found
No related merge requests found
...@@ -431,12 +431,17 @@ class RouteLayer: ...@@ -431,12 +431,17 @@ class RouteLayer:
self, self,
X: List[str], X: List[str],
y: List[str], y: List[str],
batch_size : int,
max_iter: int = 500, max_iter: int = 500,
): ):
# convert inputs into array # convert inputs into array
Xq: Any = np.array(self.encoder(X)) Xq = []
for i in tqdm(range(0, len(X), batch_size), desc= "Processing batches"):
emb = np.array(self.encoder(X[i:i+batch_size]))
Xq.extend(emb)
# initial eval (we will iterate from here) # initial eval (we will iterate from here)
best_acc = self._vec_evaluate(Xq=Xq, y=y) best_acc = self._vec_evaluate(Xq=np.array(Xq), y=y)
best_thresholds = self.get_thresholds() best_thresholds = self.get_thresholds()
# begin fit # begin fit
for _ in (pbar := tqdm(range(max_iter))): for _ in (pbar := tqdm(range(max_iter))):
...@@ -457,12 +462,16 @@ class RouteLayer: ...@@ -457,12 +462,16 @@ class RouteLayer:
# update route layer to best thresholds # update route layer to best thresholds
self._update_thresholds(score_thresholds=best_thresholds) self._update_thresholds(score_thresholds=best_thresholds)
def evaluate(self, X: List[str], y: List[str]) -> float: def evaluate(self, X: List[str], y: List[str], batch_size: int) -> float:
""" """
Evaluate the accuracy of the route selection. Evaluate the accuracy of the route selection.
""" """
Xq = np.array(self.encoder(X)) Xq = []
accuracy = self._vec_evaluate(Xq=Xq, y=y) for i in tqdm(range(0,len(X),batch_size), desc="Processing batches"):
emb = np.array(self.encoder(X[i:i+batch_size]))
Xq.extend(emb)
accuracy = self._vec_evaluate(Xq=np.array(Xq), y=y)
return accuracy return accuracy
def _vec_evaluate(self, Xq: Union[List[float], Any], y: List[str]) -> float: def _vec_evaluate(self, Xq: Union[List[float], Any], y: List[str]) -> float:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment