From 3533b8a10094d853c03c23cdd00b64bde6d0ce54 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Thu, 13 Jun 2024 01:40:13 +0800 Subject: [PATCH] feat: methods for async pinecone client --- semantic_router/index/pinecone.py | 135 +++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 1 deletion(-) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index de288524..04cfc796 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -1,3 +1,5 @@ +import aiohttp +import asyncio import hashlib import os import time @@ -46,9 +48,11 @@ class PineconeIndex(BaseIndex): region: str = "us-west-2" host: str = "" client: Any = Field(default=None, exclude=True) + async_client: Any = Field(default=None, exclude=True) index: Optional[Any] = Field(default=None, exclude=True) ServerlessSpec: Any = Field(default=None, exclude=True) namespace: Optional[str] = "" + base_url: Optional[str] = "https://api.pinecone.io" def __init__( self, @@ -60,6 +64,7 @@ class PineconeIndex(BaseIndex): region: str = "us-west-2", host: str = "", namespace: Optional[str] = "", + base_url: Optional[str] = "https://api.pinecone.io", ): super().__init__() self.index_name = index_name @@ -71,11 +76,13 @@ class PineconeIndex(BaseIndex): self.namespace = namespace self.type = "pinecone" self.api_key = api_key or os.getenv("PINECONE_API_KEY") + self.base_url = base_url if self.api_key is None: raise ValueError("Pinecone API key is required.") self.client = self._initialize_client(api_key=self.api_key) + self.async_client = self._initialize_async_client(api_key=self.api_key) def _initialize_client(self, api_key: Optional[str] = None): try: @@ -88,11 +95,22 @@ class PineconeIndex(BaseIndex): "You can install it with: " "`pip install 'semantic-router[pinecone]'`" ) - pinecone_args = {"api_key": api_key, "source_tag": "semantic-router"} + pinecone_args = {"api_key": api_key, "source_tag": "semanticrouter"} if self.namespace: pinecone_args["namespace"] = self.namespace return Pinecone(**pinecone_args) + + def _initialize_async_client(self, api_key: Optional[str] = None): + async_client = aiohttp.ClientSession( + headers={ + "Api-Key": api_key, + "Content-Type": "application/json", + "X-Pinecone-API-Version": "2024-07", + "User-Agent": "source_tag=semanticrouter" + } + ) + return async_client def _init_index(self, force_create: bool = False) -> Union[Any, None]: """Initializing the index can be done after the object has been created @@ -140,6 +158,40 @@ class PineconeIndex(BaseIndex): if index is not None: self.host = self.client.describe_index(self.index_name)["host"] return index + + async def _init_async_index(self, force_create: bool = False) -> Union[Any, None]: + index_stats = None + indexes = await self._async_list_indexes() + index_names = [i["name"] for i in indexes["indexes"]] + index_exists = self.index_name in index_names + dimensions_given = self.dimensions is not None + if dimensions_given and not index_exists: + await self._async_create_index( + name=self.index_name, + dimension=self.dimensions, + metric=self.metric, + cloud=self.cloud, + region=self.region + ) + # TODO describe index and async sleep + index_ready = "false" + while index_ready != "true": + index_stats = await self._async_describe_index(self.index_name) + index_ready = index_stats["status"]["ready"] + await asyncio.sleep(1) + elif index_exists: + index_stats = await self._async_describe_index(self.index_name) + # grab dimensions for the index + self.dimensions = index_stats["dimension"] + elif force_create and not dimensions_given: + raise ValueError( + "Cannot create an index without specifying the dimensions." + ) + else: + # if the index doesn't exist and we don't have the dimensions + # we raise warning + logger.warning("Index could not be initialized.") + self.host = index_stats["host"] if index_stats else None def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records.""" @@ -280,9 +332,90 @@ class PineconeIndex(BaseIndex): scores = [result["score"] for result in results["matches"]] route_names = [result["metadata"]["sr_route"] for result in results["matches"]] return np.array(scores), route_names + + async def aquery( + self, + vector: np.ndarray, + top_k: int = 5, + route_filter: Optional[List[str]] = None, + ) -> Tuple[np.ndarray, List[str]]: + if self.async_client is None or self.host is None: + raise ValueError("Async client or host are not initialized.") + query_vector_list = vector.tolist() + if route_filter is not None: + filter_query = {"sr_route": {"$in": route_filter}} + else: + filter_query = None + results = await self._async_query( + vector=query_vector_list, + namespace=self.namespace, + filter=filter_query, + top_k=top_k, + include_metadata=True, + ) + scores = [result["score"] for result in results["matches"]] + route_names = [result["metadata"]["sr_route"] for result in results["matches"]] + return np.array(scores), route_names + def delete_index(self): self.client.delete_index(self.index_name) + # __ASYNC CLIENT METHODS__ + async def _async_query( + self, + vector: list[float], + namespace: str = "", + filter: Optional[dict] = None, + top_k: int = 5, + include_metadata: bool = False, + ): + params = { + "vector": vector, + "namespace": namespace, + "filter": filter, + "top_k": top_k, + "include_metadata": include_metadata, + } + async with self.async_client.post( + f"https://{self.host}/query", + json=params, + ) as response: + return await response.json(content_type=None) + + async def _async_list_indexes(self): + async with self.async_client.get(f"{self.base_url}/indexes") as response: + return await response.json(content_type=None) + + async def _async_create_index( + self, + name: str, + dimension: int, + cloud: str, + region: str, + metric: str = "cosine", + ): + params = { + "name": name, + "dimension": dimension, + "metric": metric, + "spec": { + "serverless": { + "cloud": cloud, + "region": region + } + }, + } + async with self.async_client.post( + f"{self.base_url}/indexes", + headers={"Api-Key": self.api_key}, + json=params, + ) as response: + return await response.json(content_type=None) + + async def _async_describe_index(self, name: str): + async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response: + return await response.json(content_type=None) + def __len__(self): return self.index.describe_index_stats()["total_vector_count"] -- GitLab