From f269a24e9e6548a7a2d3a6c6faab50491da0bcf2 Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <ismailashraq@Ismails-MacBook-Pro.local>
Date: Sun, 15 Dec 2024 17:54:24 +0800
Subject: [PATCH] fix vector shape for single utterance

---
 semantic_router/routers/base.py     | 4 ++--
 semantic_router/routers/semantic.py | 2 --
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py
index 328cf2b7..18d71ced 100644
--- a/semantic_router/routers/base.py
+++ b/semantic_router/routers/base.py
@@ -522,7 +522,7 @@ class BaseRouter(BaseModel):
         """
         # get relevant results (scores and routes)
         results = self._retrieve(
-            xq=np.array(vector), top_k=self.top_k, route_filter=route_filter
+            xq=vector[0], top_k=self.top_k, route_filter=route_filter
         )
         # decide most relevant routes
         top_class, top_class_scores = self._semantic_classify(results)
@@ -535,7 +535,7 @@ class BaseRouter(BaseModel):
     ) -> Tuple[Optional[Route], List[float]]:
         # get relevant results (scores and routes)
         results = await self._async_retrieve(
-            xq=np.array(vector), top_k=self.top_k, route_filter=route_filter
+            xq=vector[0], top_k=self.top_k, route_filter=route_filter
         )
         # decide most relevant routes
         top_class, top_class_scores = await self._async_semantic_classify(results)
diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py
index 33af2a32..41c92d53 100644
--- a/semantic_router/routers/semantic.py
+++ b/semantic_router/routers/semantic.py
@@ -40,14 +40,12 @@ class SemanticRouter(BaseRouter):
         """Given some text, encode it."""
         # create query vector
         xq = np.array(self.encoder(text))
-        xq = np.squeeze(xq)  # Reduce to 1d array.
         return xq
 
     async def _async_encode(self, text: list[str]) -> Any:
         """Given some text, encode it."""
         # create query vector
         xq = np.array(await self.encoder.acall(docs=text))
-        xq = np.squeeze(xq)  # Reduce to 1d array.
         return xq
 
     def add(self, routes: List[Route] | Route):
-- 
GitLab