From 78857b5456ad7af9f9b835fcc8f56b650c5f7d3e Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Sat, 27 Apr 2024 18:24:12 +0800
Subject: [PATCH] fix for pytests

---
 semantic_router/index/pinecone.py | 37 ++++++++++++++++++++++++++-----
 semantic_router/layer.py          |  8 +++----
 tests/unit/test_layer.py          | 26 ++++++++++++++++++++++
 3 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index e240ed31..148c9070 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -49,12 +49,27 @@ class PineconeIndex(BaseIndex):
     ServerlessSpec: Any = Field(default=None, exclude=True)
     namespace: Optional[str] = ""
 
-    def __init__(self, **data):
-        super().__init__(**data)
-        self._initialize_client()
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        index_name: str = "index",
+        dimensions: Optional[int] = None,
+        metric: str = "cosine",
+        cloud: str = "aws",
+        region: str = "us-west-2",
+        host: str = "",
+        namespace: Optional[str] = "",
+    ):
+        super().__init__()
+        self.index_name = index_name
+        self.dimensions = dimensions
+        self.metric = metric
+        self.cloud = cloud
+        self.region = region
+        self.host = host
+        self.namespace = namespace
         self.type = "pinecone"
-        self.client = self._initialize_client()
-        self.index = self._init_index(force_create=True)
+        self.client = self._initialize_client(api_key=api_key)
 
     def _initialize_client(self, api_key: Optional[str] = None):
         try:
@@ -77,6 +92,18 @@ class PineconeIndex(BaseIndex):
         return Pinecone(**pinecone_args)
 
     def _init_index(self, force_create: bool = False) -> Union[Any, None]:
+        """Initializing the index can be done after the object has been created
+        to allow for the user to set the dimensions and other parameters.
+
+        If the index doesn't exist and the dimensions are given, the index will
+        be created. If the index exists, it will be returned. If the index doesn't
+        exist and the dimensions are not given, the index will not be created and
+        None will be returned.
+
+        :param force_create: If True, the index will be created even if the
+            dimensions are not given (which will raise an error).
+        :type force_create: bool, optional
+        """
         index_exists = self.index_name in self.client.list_indexes().names()
         dimensions_given = self.dimensions is not None
         if dimensions_given and not index_exists:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index a138893a..d9781820 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -354,7 +354,7 @@ class RouteLayer:
     def add(self, route: Route):
         logger.info(f"Adding `{route.name}` route")
         # create embeddings
-        embeds = self.encoder(route.utterances)  # type: ignore
+        embeds = self.encoder(route.utterances)
         # if route has no score_threshold, use default
         if route.score_threshold is None:
             route.score_threshold = self.score_threshold
@@ -363,7 +363,7 @@ class RouteLayer:
         self.index.add(
             embeddings=embeds,
             routes=[route.name] * len(route.utterances),
-            utterances=route.utterances,  # type: ignore
+            utterances=route.utterances,
         )
         self.routes.append(route)
 
@@ -409,14 +409,14 @@ class RouteLayer:
         all_utterances = [
             utterance for route in routes for utterance in route.utterances
         ]
-        embedded_utterances = self.encoder(all_utterances)  # type: ignore
+        embedded_utterances = self.encoder(all_utterances)
         # create route array
         route_names = [route.name for route in routes for _ in route.utterances]
         # add everything to the index
         self.index.add(
             embeddings=embedded_utterances,
             routes=route_names,
-            utterances=all_utterances,  # type: ignore
+            utterances=all_utterances,
         )
 
     def _encode(self, text: str) -> Any:
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 8f4833f0..fb5a1439 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -4,6 +4,7 @@ import tempfile
 from unittest.mock import mock_open, patch
 
 import pytest
+import time
 
 from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
 from semantic_router.index.local import LocalIndex
@@ -279,12 +280,37 @@ class TestRouteLayer:
         route_layer = RouteLayer(
             encoder=openai_encoder, routes=routes, index=pineconeindex
         )
+        time.sleep(5)  # allow for index to be populated
+        print(routes)
         query_result = route_layer(text="Hello", route_filter=["Route 1"]).name
+        print(query_result)
 
         try:
             route_layer(text="Hello", route_filter=["Route 8"]).name
         except ValueError:
             assert True
+        
+        # delete index
+        pineconeindex.delete_index()
+
+        assert query_result in ["Route 1"]
+
+    def test_namespace_pinecone_index(self, openai_encoder, routes, index_cls):
+        pinecone_api_key = os.environ["PINECONE_API_KEY"]
+        pineconeindex = PineconeIndex(api_key=pinecone_api_key, namespace="test")
+        route_layer = RouteLayer(
+            encoder=openai_encoder, routes=routes, index=pineconeindex
+        )
+        time.sleep(5)  # allow for index to be populated
+        query_result = route_layer(text="Hello", route_filter=["Route 1"]).name
+
+        try:
+            route_layer(text="Hello", route_filter=["Route 8"]).name
+        except ValueError:
+            assert True
+
+        # delete index
+        pineconeindex.delete_index()
 
         assert query_result in ["Route 1"]
 
-- 
GitLab