From 2745d3e0694310218af2369af7219beec7c70ec4 Mon Sep 17 00:00:00 2001
From: Siraj R Aizlewood <siraj@aurelio.ai>
Date: Thu, 8 Feb 2024 01:09:11 +0400
Subject: [PATCH] Tidying up and Bug Fixing

---
 poetry.lock                         |  4 +--
 pyproject.toml                      |  4 +--
 semantic_router/indices/pinecone.py | 40 +++++++++++++++++++----------
 semantic_router/layer.py            |  7 ++---
 semantic_router/schema.py           | 18 +------------
 5 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 15054e5d..d5e1153b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3403,5 +3403,5 @@ pinecone = ["pinecone-client"]
 
 [metadata]
 lock-version = "2.0"
-python-versions = "^3.9"
-content-hash = "10a0117bd6c131c255db119d7c83ace94f09dbe77ac6a4bba7bf0d90da54fad0"
+python-versions = ">=3.9,<3.13"
+content-hash = "52ce34492a7d4827a3c2b96332e7285369209dbe9ec2a9488d8eac2c13d4d0c6"
diff --git a/pyproject.toml b/pyproject.toml
index a945ac65..1491b6d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ readme = "README.md"
 packages = [{include = "semantic_router"}]
 
 [tool.poetry.dependencies]
-python = "^3.9"
+python = ">=3.9,<3.13"
 pydantic = "^2.5.3"
 openai = "^1.10.0"
 cohere = "^4.32"
@@ -29,7 +29,7 @@ transformers = {version = "^4.36.2", optional = true}
 llama-cpp-python = {version = "^0.2.28", optional = true}
 black = "^23.12.1"
 colorama = "^0.4.6"
-pinecone-client = {version = "^3.0.0", optional = true, python=">=3.9,<3.13"}
+pinecone-client = {version="^3.0.0", optional = true}
 [tool.poetry.extras]
 hybrid = ["pinecone-text"]
 fastembed = ["fastembed"]
diff --git a/semantic_router/indices/pinecone.py b/semantic_router/indices/pinecone.py
index c67c0c51..443565b6 100644
--- a/semantic_router/indices/pinecone.py
+++ b/semantic_router/indices/pinecone.py
@@ -1,25 +1,39 @@
+from pydantic import BaseModel, Field
 import os
 import pinecone
-import numpy as np
-from typing import List, Tuple
+from typing import Any, List, Tuple
 from semantic_router.indices.base import BaseIndex
-
+import numpy as np
 
 class PineconeIndex(BaseIndex):
-    def __init__(self, index_name: str, environment: str = 'us-west1-gcp', metric: str = 'cosine', dimension: int = 768):
-        super().__init__()
-        
-        # Initialize Pinecone environment
-        pinecone.init(api_key=os.getenv("PINECONE_API_KEY"), environment=environment)
+    index_name: str
+    dimension: int = 768
+    metric: str = "cosine"
+    cloud: str = "aws"
+    region: str = "us-west-2" 
+    pinecone: Any = Field(default=None, exclude=True)
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        # Create the Pinecone client (v3 API: a client object replaces pinecone.init)
+        self.pinecone = pinecone.Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
         
         # Create or connect to an existing Pinecone index
-        if index_name not in pinecone.list_indexes():
-            print(f"Creating new Pinecone index: {index_name}")
-            pinecone.create_index(name=index_name, metric=metric, dimension=dimension)
-        self.index = pinecone.Index(index_name)
+        if self.index_name not in self.pinecone.list_indexes().names():
+            print(f"Creating new Pinecone index: {self.index_name}")
+            self.pinecone.create_index(
+                name=self.index_name, 
+                dimension=self.dimension, 
+                metric=self.metric,
+                spec=pinecone.ServerlessSpec(
+                    cloud=self.cloud,
+                    region=self.region
+                )
+            )
+        self.index = self.pinecone.Index(self.index_name)
         
         # Store the index name for potential deletion
-        self.index_name = index_name
+        self.index_name = self.index_name
 
     def add(self, embeds: List[np.ndarray]):
         # Assuming embeds is a list of tuples (id, vector)
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 9794e2e2..cff166c3 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -10,9 +10,10 @@ from tqdm.auto import tqdm
 from semantic_router.encoders import BaseEncoder, OpenAIEncoder
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
-from semantic_router.schema import Encoder, EncoderType, RouteChoice, Index
+from semantic_router.schema import Encoder, EncoderType, RouteChoice
 from semantic_router.utils.logger import logger
 from semantic_router.indices.base import BaseIndex
+from semantic_router.indices.local_index import LocalIndex
 
 
 def is_valid(layer_config: str) -> bool:
@@ -161,10 +162,10 @@ class RouteLayer:
         encoder: Optional[BaseEncoder] = None,
         llm: Optional[BaseLLM] = None,
         routes: Optional[List[Route]] = None,
-        index_name: Optional[str] = "local",
+        index: Optional[BaseIndex] = None,
     ):
         logger.info("local")
-        self.index: BaseIndex = Index.get_by_name(index_name=index_name)
+        self.index: BaseIndex = index if index is not None else LocalIndex()
         self.categories = None
         if encoder is None:
             logger.warning(
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index df66ecd5..61f9b0b6 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -13,7 +13,7 @@ from semantic_router.encoders import (
 
 from semantic_router.indices.local_index import LocalIndex
 from semantic_router.indices.pinecone import PineconeIndex
-
+from semantic_router.indices.base import BaseIndex
 
 class EncoderType(Enum):
     HUGGINGFACE = "huggingface"
@@ -76,19 +76,3 @@ class DocumentSplit(BaseModel):
     docs: List[str]
     is_triggered: bool = False
     triggered_score: Optional[float] = None
-
-
-class Index:
-    index_map = {
-        "local": LocalIndex,
-        "pinecone": PineconeIndex,
-    }
-
-    @classmethod
-    def get_by_name(cls, index_name: Optional[str] = None):
-        if index_name is None:
-            index_name = "local"
-        try:
-            return cls.index_map[index_name]()
-        except KeyError:
-            raise ValueError(f"Invalid index name: {index_name}")
-- 
GitLab