Skip to content
Snippets Groups Projects
Commit a5f0c254 authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

set default batch size and update progress bar description

parent eb3245b1
No related branches found
No related tags found
No related merge requests found
......@@ -110,9 +110,9 @@ class LayerConfig:
llm_class = getattr(llm_module, llm_data["class"])
# Instantiate the LLM class with the provided model name
llm = llm_class(name=llm_data["model"])
route_data[
"llm"
] = llm # Reassign the instantiated llm object back to route_data
route_data["llm"] = (
llm # Reassign the instantiated llm object back to route_data
)
# Dynamically create the Route object using the remaining route_data
route = Route(**route_data)
......@@ -431,19 +431,19 @@ class RouteLayer:
self,
X: List[str],
y: List[str],
batch_size: int,
batch_size: int = 500,
max_iter: int = 500,
):
# convert inputs into array
Xq: List[List[float]] = []
for i in tqdm(range(0, len(X), batch_size), desc="Processing batches"):
for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
emb = np.array(self.encoder(X[i : i + batch_size]))
Xq.extend(emb)
# initial eval (we will iterate from here)
best_acc = self._vec_evaluate(Xq=np.array(Xq), y=y)
best_thresholds = self.get_thresholds()
# begin fit
for _ in (pbar := tqdm(range(max_iter))):
for _ in (pbar := tqdm(range(max_iter), desc="Training")):
pbar.set_postfix({"acc": round(best_acc, 2)})
# Find the best score threshold for each route
thresholds = threshold_random_search(
......@@ -461,12 +461,12 @@ class RouteLayer:
# update route layer to best thresholds
self._update_thresholds(score_thresholds=best_thresholds)
def evaluate(self, X: List[str], y: List[str], batch_size: int) -> float:
def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float:
"""
Evaluate the accuracy of the route selection.
"""
Xq: List[List[float]] = []
for i in tqdm(range(0, len(X), batch_size), desc="Processing batches"):
for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
emb = np.array(self.encoder(X[i : i + batch_size]))
Xq.extend(emb)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment