diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a3e9023c19904291817b73f2b47b14c328b5d0ac..f94862c0d44dd19a6c431c3ea5f15febb5e0179d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,11 +11,11 @@ Please follow these guidelines when making a contribution: - Ensure that the Pull Request title is prepended with a [valid type](https://flank.github.io/flank/pr_titles/). E.g. `feat: My New Feature`. - Run linting (and fix any issues that are flagged) by: - Navigating to /semantic-router. - - Running `poetry run make lint` to fix linting issues. - - Running `poetry run black .` to fix `black` linting issues. - - Running `poetry run ruff . --fix` to fix `ruff` linting issues (where possible, others may need manual changes). - - Confirming the linters pass using `poetry run make lint` again. - - Running `ruff . --fix`. + - Running `make lint` to fix linting issues. + - Running `black .` to fix `black` linting issues. + - Running `ruff . --fix` to fix `ruff` linting issues (where possible, others may need manual changes). + - Running `mypy .` and then fixing any of the issues that are raised. + - Confirming the linters pass using `make lint` again. - Ensure that, for any new code, new [PyTests are written](https://github.com/aurelio-labs/semantic-router/tree/main/tests/unit). If any code is removed, then ensure that corresponding PyTests are also removed. Finally, ensure that all remaining PyTests pass using `pytest ./tests` (to avoid integration tests you can run `pytest ./tests/unit`). - Codecov checks will inform you if any code is not covered by PyTests upon creating the PR. You should aim to cover new code with PyTests. 
diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 640a68a7034db84131f1768dde02204fedf38277..186b32d0e8b829aced4fa1e2caeb7d75bf048b14 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -431,15 +431,19 @@ class RouteLayer: self, X: List[str], y: List[str], + batch_size: int = 500, max_iter: int = 500, ): # convert inputs into array - Xq: Any = np.array(self.encoder(X)) + Xq: List[List[float]] = [] + for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): + emb = np.array(self.encoder(X[i : i + batch_size])) + Xq.extend(emb) # initial eval (we will iterate from here) - best_acc = self._vec_evaluate(Xq=Xq, y=y) + best_acc = self._vec_evaluate(Xq=np.array(Xq), y=y) best_thresholds = self.get_thresholds() # begin fit - for _ in (pbar := tqdm(range(max_iter))): + for _ in (pbar := tqdm(range(max_iter), desc="Training")): pbar.set_postfix({"acc": round(best_acc, 2)}) # Find the best score threshold for each route thresholds = threshold_random_search( @@ -457,12 +461,16 @@ class RouteLayer: # update route layer to best thresholds self._update_thresholds(score_thresholds=best_thresholds) - def evaluate(self, X: List[str], y: List[str]) -> float: + def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float: """ Evaluate the accuracy of the route selection. 
""" - Xq = np.array(self.encoder(X)) - accuracy = self._vec_evaluate(Xq=Xq, y=y) + Xq: List[List[float]] = [] + for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"): + emb = np.array(self.encoder(X[i : i + batch_size])) + Xq.extend(emb) + + accuracy = self._vec_evaluate(Xq=np.array(Xq), y=y) return accuracy def _vec_evaluate(self, Xq: Union[List[float], Any], y: List[str]) -> float: diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 00bad4ff5f32dda8c0ec748b2f16388ef7e3e101..3f2c413f45f279656bada2ffdc7ee47b96ef1a6c 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -432,13 +432,13 @@ class TestLayerFit: # unpack test data X, y = zip(*test_data) # evaluate - route_layer.evaluate(X=X, y=y) + route_layer.evaluate(X=X, y=y, batch_size=int(len(test_data) / 5)) def test_fit(self, openai_encoder, routes, test_data): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) # unpack test data X, y = zip(*test_data) - route_layer.fit(X=X, y=y) + route_layer.fit(X=X, y=y, batch_size=int(len(test_data) / 5)) # Add more tests for edge cases and error handling as needed.