From a5f0c254c0dcd4f15a8243be227e5affea7cff7d Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <issey1455@gmail.com>
Date: Wed, 21 Feb 2024 13:54:38 +0500
Subject: [PATCH] set default batch size and update progress bar description

---
 semantic_router/layer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index e7381df0..694f4ebe 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -110,9 +110,9 @@ class LayerConfig:
                     llm_class = getattr(llm_module, llm_data["class"])
                     # Instantiate the LLM class with the provided model name
                     llm = llm_class(name=llm_data["model"])
-                    route_data[
-                        "llm"
-                    ] = llm  # Reassign the instantiated llm object back to route_data
+                    route_data["llm"] = (
+                        llm  # Reassign the instantiated llm object back to route_data
+                    )
 
                 # Dynamically create the Route object using the remaining route_data
                 route = Route(**route_data)
@@ -431,19 +431,19 @@ class RouteLayer:
         self,
         X: List[str],
         y: List[str],
-        batch_size: int,
+        batch_size: int = 500,
         max_iter: int = 500,
     ):
         # convert inputs into array
         Xq: List[List[float]] = []
-        for i in tqdm(range(0, len(X), batch_size), desc="Processing batches"):
+        for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
             emb = np.array(self.encoder(X[i : i + batch_size]))
             Xq.extend(emb)
         # initial eval (we will iterate from here)
         best_acc = self._vec_evaluate(Xq=np.array(Xq), y=y)
         best_thresholds = self.get_thresholds()
         # begin fit
-        for _ in (pbar := tqdm(range(max_iter))):
+        for _ in (pbar := tqdm(range(max_iter), desc="Training")):
             pbar.set_postfix({"acc": round(best_acc, 2)})
             # Find the best score threshold for each route
             thresholds = threshold_random_search(
@@ -461,12 +461,12 @@ class RouteLayer:
         # update route layer to best thresholds
         self._update_thresholds(score_thresholds=best_thresholds)
 
-    def evaluate(self, X: List[str], y: List[str], batch_size: int) -> float:
+    def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float:
         """
         Evaluate the accuracy of the route selection.
         """
         Xq: List[List[float]] = []
-        for i in tqdm(range(0, len(X), batch_size), desc="Processing batches"):
+        for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
             emb = np.array(self.encoder(X[i : i + batch_size]))
             Xq.extend(emb)
 
-- 
GitLab