From ee0d792b543733afb3e8005eb99728eabb6e2712 Mon Sep 17 00:00:00 2001 From: Juan Pablo Mesa Lopez <mesax1@gmail.com> Date: Fri, 26 Apr 2024 00:57:30 -0500 Subject: [PATCH] Added fix to _encode_documents within rolling_window.py when len(docs)> 2048 openai limit --- semantic_router/splitters/rolling_window.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/semantic_router/splitters/rolling_window.py b/semantic_router/splitters/rolling_window.py index a2809ff5..2e02253d 100644 --- a/semantic_router/splitters/rolling_window.py +++ b/semantic_router/splitters/rolling_window.py @@ -100,12 +100,19 @@ class RollingWindowSplitter(BaseSplitter): return splits def _encode_documents(self, docs: List[str]) -> np.ndarray: - try: - embeddings = self.encoder(docs) - return np.array(embeddings) - except Exception as e: - logger.error(f"Error encoding documents {docs}: {e}") - raise + max_docs_per_batch = 2000 # OpenAI limit is 2048 + embeddings = [] + + for i in range(0, len(docs), max_docs_per_batch): + batch_docs = docs[i : i + max_docs_per_batch] + try: + batch_embeddings = self.encoder(batch_docs) + embeddings.extend(batch_embeddings) + except Exception as e: + logger.error(f"Error encoding documents {batch_docs}: {e}") + raise + + return np.array(embeddings) def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]: raw_similarities = [] -- GitLab