From 3533b8a10094d853c03c23cdd00b64bde6d0ce54 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Thu, 13 Jun 2024 01:40:13 +0800
Subject: [PATCH] feat: methods for async pinecone client

---
 semantic_router/index/pinecone.py | 135 +++++++++++++++++++++++++++++-
 1 file changed, 134 insertions(+), 1 deletion(-)

diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index de288524..04cfc796 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -1,3 +1,5 @@
+import aiohttp
+import asyncio
 import hashlib
 import os
 import time
@@ -46,9 +48,11 @@ class PineconeIndex(BaseIndex):
     region: str = "us-west-2"
     host: str = ""
     client: Any = Field(default=None, exclude=True)
+    async_client: Any = Field(default=None, exclude=True)
     index: Optional[Any] = Field(default=None, exclude=True)
     ServerlessSpec: Any = Field(default=None, exclude=True)
     namespace: Optional[str] = ""
+    base_url: Optional[str] = "https://api.pinecone.io"
 
     def __init__(
         self,
@@ -60,6 +64,7 @@ class PineconeIndex(BaseIndex):
         region: str = "us-west-2",
         host: str = "",
         namespace: Optional[str] = "",
+        base_url: Optional[str] = "https://api.pinecone.io",
     ):
         super().__init__()
         self.index_name = index_name
@@ -71,11 +76,13 @@ class PineconeIndex(BaseIndex):
         self.namespace = namespace
         self.type = "pinecone"
         self.api_key = api_key or os.getenv("PINECONE_API_KEY")
+        self.base_url = base_url
 
         if self.api_key is None:
             raise ValueError("Pinecone API key is required.")
 
         self.client = self._initialize_client(api_key=self.api_key)
+        self.async_client = self._initialize_async_client(api_key=self.api_key)
 
     def _initialize_client(self, api_key: Optional[str] = None):
         try:
@@ -88,11 +95,22 @@ class PineconeIndex(BaseIndex):
                 "You can install it with: "
                 "`pip install 'semantic-router[pinecone]'`"
             )
-        pinecone_args = {"api_key": api_key, "source_tag": "semantic-router"}
+        pinecone_args = {"api_key": api_key, "source_tag": "semanticrouter"}
         if self.namespace:
             pinecone_args["namespace"] = self.namespace
 
         return Pinecone(**pinecone_args)
+    
+    def _initialize_async_client(self, api_key: Optional[str] = None):
+        """Build the aiohttp session that carries the Pinecone REST headers."""
+        # session-level defaults: sent with every request made via this session
+        default_headers = {
+            "Api-Key": api_key,
+            "Content-Type": "application/json",
+            "X-Pinecone-API-Version": "2024-07",
+            "User-Agent": "source_tag=semanticrouter",
+        }
+        return aiohttp.ClientSession(headers=default_headers)
 
     def _init_index(self, force_create: bool = False) -> Union[Any, None]:
         """Initializing the index can be done after the object has been created
@@ -140,6 +158,40 @@ class PineconeIndex(BaseIndex):
         if index is not None:
             self.host = self.client.describe_index(self.index_name)["host"]
         return index
+    
+    async def _init_async_index(self, force_create: bool = False) -> Union[Any, None]:
+        """Create or look up the index over REST, caching its dimensions and host.
+
+        Raises ValueError if a missing index is force-created without dimensions.
+        """
+        index_stats = None
+        indexes = await self._async_list_indexes()
+        index_exists = any(i["name"] == self.index_name for i in indexes["indexes"])
+        dimensions_given = self.dimensions is not None
+        if dimensions_given and not index_exists:
+            await self._async_create_index(
+                name=self.index_name,
+                dimension=self.dimensions,
+                metric=self.metric,
+                cloud=self.cloud,
+                region=self.region,
+            )
+            # `ready` is decoded from JSON as a boolean, so compare its text form
+            index_stats = await self._async_describe_index(self.index_name)
+            while str(index_stats["status"]["ready"]).lower() != "true":
+                await asyncio.sleep(1)
+                index_stats = await self._async_describe_index(self.index_name)
+        elif index_exists:
+            index_stats = await self._async_describe_index(self.index_name)
+            self.dimensions = index_stats["dimension"]
+        elif force_create and not dimensions_given:
+            raise ValueError(
+                "Cannot create an index without specifying the dimensions."
+            )
+        else:
+            # the index doesn't exist and we don't have the dimensions to create it
+            logger.warning("Index could not be initialized.")
+        self.host = index_stats["host"] if index_stats else None
 
     def _batch_upsert(self, batch: List[Dict]):
         """Helper method for upserting a single batch of records."""
@@ -280,9 +332,90 @@ class PineconeIndex(BaseIndex):
         scores = [result["score"] for result in results["matches"]]
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
         return np.array(scores), route_names
+    
+    async def aquery(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        """Asynchronously return the scores and route names of the `top_k` matches."""
+        # `host` defaults to "", so an unset host must be caught by falsiness
+        if self.async_client is None or not self.host:
+            raise ValueError("Async client or host are not initialized.")
+        filter_query = None
+        if route_filter is not None:
+            filter_query = {"sr_route": {"$in": route_filter}}
+        results = await self._async_query(
+            vector=vector.tolist(),
+            namespace=self.namespace,
+            filter=filter_query,
+            top_k=top_k,
+            include_metadata=True,
+        )
+        matches = results["matches"]
+        scores = np.array([match["score"] for match in matches])
+        return scores, [match["metadata"]["sr_route"] for match in matches]
+        
 
     def delete_index(self):
         self.client.delete_index(self.index_name)
 
+    # __ASYNC CLIENT METHODS__
+    async def _async_query(
+        self,
+        vector: List[float],
+        namespace: str = "",
+        filter: Optional[dict] = None,
+        top_k: int = 5,
+        include_metadata: bool = False,
+    ):
+        """POST a vector query to the index host and return the decoded JSON."""
+        # Pinecone's REST query body documents camelCase keys (`topK`, ...)
+        params = {
+            "vector": vector,
+            "namespace": namespace,
+            "filter": filter,
+            "topK": top_k,
+            "includeMetadata": include_metadata,
+        }
+        url = f"https://{self.host}/query"
+        async with self.async_client.post(url, json=params) as response:
+            return await response.json(content_type=None)
+
+    async def _async_list_indexes(self):
+        """GET `{base_url}/indexes` and return the decoded JSON body."""
+        endpoint = f"{self.base_url}/indexes"
+        async with self.async_client.get(endpoint) as response:
+            return await response.json(content_type=None)
+
+    async def _async_create_index(
+        self,
+        name: str,
+        dimension: int,
+        cloud: str,
+        region: str,
+        metric: str = "cosine",
+    ):
+        """POST a serverless index definition and return the decoded JSON body."""
+        serverless_spec = {"serverless": {"cloud": cloud, "region": region}}
+        params = {
+            "name": name,
+            "dimension": dimension,
+            "metric": metric,
+            "spec": serverless_spec,
+        }
+        async with self.async_client.post(
+            f"{self.base_url}/indexes",
+            headers={"Api-Key": self.api_key},
+            json=params,
+        ) as response:
+            return await response.json(content_type=None)
+
+    async def _async_describe_index(self, name: str):
+        """GET `{base_url}/indexes/{name}` and return the decoded JSON body."""
+        async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
+            return await response.json(content_type=None)
+
     def __len__(self):
         return self.index.describe_index_stats()["total_vector_count"]
-- 
GitLab