diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 5e71da2a76b57fd59236d35469dcb9bfad77341f..2b65c80b4ff863d1d86dc8e540bad58035424b31 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -153,10 +153,14 @@ class PineconeIndex(BaseIndex): raise ValueError("Pinecone API key is required.") self.client = self._initialize_client(api_key=self.api_key) - if init_async_index: - self.async_client = self._initialize_async_client(api_key=self.api_key) - else: - self.async_client = None + + self.api_key = api_key + self.headers = { + "Api-Key": self.api_key, + "Content-Type": "application/json", + "X-Pinecone-API-Version": "2024-07", + "User-Agent": "source_tag=semanticrouter", + } # try initializing index self.index = self._init_index() @@ -659,8 +663,8 @@ class PineconeIndex(BaseIndex): :rtype: Tuple[np.ndarray, List[str]] :raises ValueError: If the index is not populated. """ - if self.async_client is None or self.host == "": - raise ValueError("Async client or host are not initialized.") + if self.host == "": + raise ValueError("Host is not initialized.") query_vector_list = vector.tolist() if route_filter is not None: filter_query = {"sr_route": {"$in": route_filter}} @@ -693,8 +697,8 @@ class PineconeIndex(BaseIndex): :return: A list of (route_name, utterance) objects. :rtype: List[Tuple] """ - if self.async_client is None or self.host == "": - raise ValueError("Async client or host are not initialized.") + if self.host == "": + raise ValueError("Host is not initialized.") return await self._async_get_routes() @@ -722,15 +726,21 @@ class PineconeIndex(BaseIndex): } if self.host == "": raise ValueError("self.host is not initialized.") - async with self.async_client.post( - f"https://{self.host}/query", - json=params, - ) as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{self.host}/query", + json=params, + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_list_indexes(self): - async with self.async_client.get(f"{self.base_url}/indexes") as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.base_url}/indexes", + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_upsert( self, @@ -741,12 +751,14 @@ class PineconeIndex(BaseIndex): "vectors": vectors, "namespace": namespace, } - async with self.async_client.post( - f"https://{self.host}/vectors/upsert", - json=params, - ) as response: - res = await response.json(content_type=None) - return res + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{self.host}/vectors/upsert", + json=params, + headers=self.headers, + ) as response: + res = await response.json(content_type=None) + return res async def _async_create_index( self, @@ -762,26 +774,34 @@ class PineconeIndex(BaseIndex): "metric": metric, "spec": {"serverless": {"cloud": cloud, "region": region}}, } - async with self.async_client.post( - f"{self.base_url}/indexes", - json=params, - ) as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/indexes", + json=params, + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_delete(self, ids: list[str], namespace: str = ""): params = { "ids": ids, "namespace": namespace, } - async with self.async_client.post( - f"https://{self.host}/vectors/delete", - json=params, - ) as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{self.host}/vectors/delete", + json=params, + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_describe_index(self, name: str): - async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.base_url}/indexes/{name}", + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_get_all( self, prefix: Optional[str] = None, include_metadata: bool = False @@ -819,13 +839,16 @@ class PineconeIndex(BaseIndex): if next_page_token: params["paginationToken"] = next_page_token - async with self.async_client.get( - list_url, params=params, headers={"Api-Key": self.api_key} - ) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Error fetching vectors: {error_text}") - break + async with aiohttp.ClientSession() as session: + async with session.get( + list_url, + params=params, + headers=self.headers, + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Error fetching vectors: {error_text}") + break response_data = await response.json(content_type=None) @@ -877,23 +900,24 @@ class PineconeIndex(BaseIndex): "Api-Key": self.api_key, } - async with self.async_client.get( - url, params=params, headers=headers - ) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Error fetching metadata: {error_text}") - return {} - - try: - response_data = await response.json(content_type=None) - except Exception as e: - logger.warning(f"No metadata found for vector {vector_id}: {e}") - return {} - - return ( - response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {}) - ) + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params, headers=headers) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Error fetching metadata: {error_text}") + return {} + + try: + response_data = await response.json(content_type=None) + except Exception as e: + logger.warning(f"No metadata found for vector {vector_id}: {e}") + return {} + + return ( + response_data.get("vectors", {}) + .get(vector_id, {}) + .get("metadata", {}) + ) def __len__(self): namespace_stats = self.index.describe_index_stats()["namespaces"].get(