From a5f0c254c0dcd4f15a8243be227e5affea7cff7d Mon Sep 17 00:00:00 2001 From: Ismail Ashraq <issey1455@gmail.com> Date: Wed, 21 Feb 2024 13:54:38 +0500 Subject: [PATCH] set default batch size and update progress bar description --- semantic_router/layer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/semantic_router/layer.py b/semantic_router/layer.py index e7381df0..694f4ebe 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -110,9 +110,9 @@ class LayerConfig: llm_class = getattr(llm_module, llm_data["class"]) # Instantiate the LLM class with the provided model name llm = llm_class(name=llm_data["model"]) - route_data[ - "llm" - ] = llm # Reassign the instantiated llm object back to route_data + route_data["llm"] = ( + llm # Reassign the instantiated llm object back to route_data + ) # Dynamically create the Route object using the remaining route_data route = Route(**route_data) @@ -431,19 +431,19 @@ class RouteLayer: self, X: List[str], y: List[str], - batch_size: int, + batch_size: int = 500, max_iter: int = 500, ): # convert inputs into array Xq: List[List[float]] = [] - for i in tqdm(range(0, len(X), batch_size), desc="Processing batches"): + for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): emb = np.array(self.encoder(X[i : i + batch_size])) Xq.extend(emb) # initial eval (we will iterate from here) best_acc = self._vec_evaluate(Xq=np.array(Xq), y=y) best_thresholds = self.get_thresholds() # begin fit - for _ in (pbar := tqdm(range(max_iter))): + for _ in (pbar := tqdm(range(max_iter), desc="Training")): pbar.set_postfix({"acc": round(best_acc, 2)}) # Find the best score threshold for each route thresholds = threshold_random_search( @@ -461,12 +461,12 @@ class RouteLayer: # update route layer to best thresholds self._update_thresholds(score_thresholds=best_thresholds) - def evaluate(self, X: List[str], y: List[str], batch_size: int) -> float: + def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float: """ Evaluate the accuracy of the route selection. """ Xq: List[List[float]] = [] - for i in tqdm(range(0, len(X), batch_size), desc="Processing batches"): + for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): emb = np.array(self.encoder(X[i : i + batch_size])) Xq.extend(emb) -- GitLab