diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 5e71da2a76b57fd59236d35469dcb9bfad77341f..15889013b300533e78223b0a2ae5d20d2c274965 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -112,11 +112,11 @@ class PineconeIndex(BaseIndex): region: str = "us-east-1" host: str = "" client: Any = Field(default=None, exclude=True) - async_client: Any = Field(default=None, exclude=True) index: Optional[Any] = Field(default=None, exclude=True) ServerlessSpec: Any = Field(default=None, exclude=True) namespace: Optional[str] = "" base_url: Optional[str] = "https://api.pinecone.io" + headers: dict[str, str] = {} def __init__( self, @@ -132,6 +132,15 @@ class PineconeIndex(BaseIndex): init_async_index: bool = False, ): super().__init__() + self.api_key = api_key or os.getenv("PINECONE_API_KEY") + if not self.api_key: + raise ValueError("Pinecone API key is required.") + self.headers = { + "Api-Key": self.api_key, + "Content-Type": "application/json", + "X-Pinecone-API-Version": "2024-07", + "User-Agent": "source_tag=semanticrouter", + } self.index_name = index_name self.dimensions = dimensions self.metric = metric @@ -153,10 +162,7 @@ class PineconeIndex(BaseIndex): raise ValueError("Pinecone API key is required.") self.client = self._initialize_client(api_key=self.api_key) - if init_async_index: - self.async_client = self._initialize_async_client(api_key=self.api_key) - else: - self.async_client = None + # try initializing index self.index = self._init_index() @@ -181,20 +187,6 @@ class PineconeIndex(BaseIndex): return Pinecone(**pinecone_args) - def _initialize_async_client(self, api_key: Optional[str] = None): - api_key = api_key or self.api_key - if api_key is None: - raise ValueError("Pinecone API key is required.") - async_client = aiohttp.ClientSession( - headers={ - "Api-Key": api_key, - "Content-Type": "application/json", - "X-Pinecone-API-Version": "2024-07", - "User-Agent": "source_tag=semanticrouter", - } - ) - return async_client - def _init_index(self, force_create: bool = False) -> Union[Any, None]: """Initializing the index can be done after the object has been created to allow for the user to set the dimensions and other parameters. @@ -659,8 +651,8 @@ class PineconeIndex(BaseIndex): :rtype: Tuple[np.ndarray, List[str]] :raises ValueError: If the index is not populated. """ - if self.async_client is None or self.host == "": - raise ValueError("Async client or host are not initialized.") + if self.host == "": + raise ValueError("Host is not initialized.") query_vector_list = vector.tolist() if route_filter is not None: filter_query = {"sr_route": {"$in": route_filter}} @@ -693,8 +685,8 @@ class PineconeIndex(BaseIndex): :return: A list of (route_name, utterance) objects. :rtype: List[Tuple] """ - if self.async_client is None or self.host == "": - raise ValueError("Async client or host are not initialized.") + if self.host == "": + raise ValueError("Host is not initialized.") return await self._async_get_routes() @@ -722,15 +714,21 @@ class PineconeIndex(BaseIndex): } if self.host == "": raise ValueError("self.host is not initialized.") - async with self.async_client.post( - f"https://{self.host}/query", - json=params, - ) as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{self.host}/query", + json=params, + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_list_indexes(self): - async with self.async_client.get(f"{self.base_url}/indexes") as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.base_url}/indexes", + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_upsert( self, @@ -741,12 +739,14 @@ class PineconeIndex(BaseIndex): "vectors": vectors, "namespace": namespace, } - async with self.async_client.post( - f"https://{self.host}/vectors/upsert", - json=params, - ) as response: - res = await response.json(content_type=None) - return res + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{self.host}/vectors/upsert", + json=params, + headers=self.headers, + ) as response: + res = await response.json(content_type=None) + return res async def _async_create_index( self, @@ -762,26 +762,34 @@ class PineconeIndex(BaseIndex): "metric": metric, "spec": {"serverless": {"cloud": cloud, "region": region}}, } - async with self.async_client.post( - f"{self.base_url}/indexes", - json=params, - ) as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/indexes", + json=params, + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_delete(self, ids: list[str], namespace: str = ""): params = { "ids": ids, "namespace": namespace, } - async with self.async_client.post( - f"https://{self.host}/vectors/delete", - json=params, - ) as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.post( + f"https://{self.host}/vectors/delete", + json=params, + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_describe_index(self, name: str): - async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response: - return await response.json(content_type=None) + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.base_url}/indexes/{name}", + headers=self.headers, + ) as response: + return await response.json(content_type=None) async def _async_get_all( self, prefix: Optional[str] = None, include_metadata: bool = False @@ -815,33 +823,38 @@ class PineconeIndex(BaseIndex): params["namespace"] = self.namespace metadata = [] - while True: - if next_page_token: - params["paginationToken"] = next_page_token - - async with self.async_client.get( - list_url, params=params, headers={"Api-Key": self.api_key} - ) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Error fetching vectors: {error_text}") + async with aiohttp.ClientSession() as session: + while True: + if next_page_token: + params["paginationToken"] = next_page_token + + async with session.get( + list_url, + params=params, + headers=self.headers, + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Error fetching vectors: {error_text}") + break + + response_data = await response.json(content_type=None) + + vector_ids = [vec["id"] for vec in response_data.get("vectors", [])] + if not vector_ids: break + all_vector_ids.extend(vector_ids) - response_data = await response.json(content_type=None) + if include_metadata: + metadata_tasks = [ + self._async_fetch_metadata(id) for id in vector_ids + ] + metadata_results = await asyncio.gather(*metadata_tasks) + metadata.extend(metadata_results) - vector_ids = [vec["id"] for vec in response_data.get("vectors", [])] - if not vector_ids: - break - all_vector_ids.extend(vector_ids) - - if include_metadata: - metadata_tasks = [self._async_fetch_metadata(id) for id in vector_ids] - metadata_results = await asyncio.gather(*metadata_tasks) - metadata.extend(metadata_results) - - next_page_token = response_data.get("pagination", {}).get("next") - if not next_page_token: - break + next_page_token = response_data.get("pagination", {}).get("next") + if not next_page_token: + break return all_vector_ids, metadata @@ -851,7 +864,7 @@ class PineconeIndex(BaseIndex): namespace: str | None = None, ) -> dict: """Fetch metadata for a single vector ID asynchronously using the - async_client. + ClientSession. :param vector_id: The ID of the vector to fetch metadata for. :type vector_id: str @@ -873,27 +886,26 @@ class PineconeIndex(BaseIndex): elif self.namespace: params["namespace"] = [self.namespace] - headers = { - "Api-Key": self.api_key, - } - - async with self.async_client.get( - url, params=params, headers=headers - ) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Error fetching metadata: {error_text}") - return {} - - try: - response_data = await response.json(content_type=None) - except Exception as e: - logger.warning(f"No metadata found for vector {vector_id}: {e}") - return {} - - return ( - response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {}) - ) + async with aiohttp.ClientSession() as session: + async with session.get( + url, params=params, headers=self.headers + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Error fetching metadata: {error_text}") + return {} + + try: + response_data = await response.json(content_type=None) + except Exception as e: + logger.warning(f"No metadata found for vector {vector_id}: {e}") + return {} + + return ( + response_data.get("vectors", {}) + .get(vector_id, {}) + .get("metadata", {}) + ) def __len__(self): namespace_stats = self.index.describe_index_stats()["namespaces"].get( diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index ab115db09a9e8811553c83aaf0c8621ead2cee31..a210cc4cd32075dc8f43e522316d95d2a1cecfe1 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -834,6 +834,7 @@ class TestSemanticRouter: index=pineconeindex, auto_sync="local", ) + time.sleep(PINECONE_SLEEP) # allow for index to be updated @retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) def check_query_result(): diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index feae4b5d103519bea2fe8687a8b6dd64e3e662f5..d40a57fb0a1feb80d34aa2a8d44da21b44ba3db9 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -416,6 +416,7 @@ class TestSemanticRouter: index=pinecone_index, auto_sync="local", ) + time.sleep(PINECONE_SLEEP) @retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) def check_sync(): @@ -449,6 +450,7 @@ class TestSemanticRouter: index=pinecone_index, auto_sync="remote", ) + time.sleep(PINECONE_SLEEP) @retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) def check_sync(): @@ -474,13 +476,14 @@ class TestSemanticRouter: index=pinecone_index, auto_sync="local", ) - time.sleep(PINECONE_SLEEP) # allow for index to be populated + time.sleep(PINECONE_SLEEP * 2) # allow for index to be populated route_layer = router_cls( encoder=openai_encoder, routes=routes_2, index=pinecone_index, auto_sync="merge-force-local", ) + time.sleep(PINECONE_SLEEP * 2) @retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) def check_sync(): @@ -750,8 +753,8 @@ class TestAsyncSemanticRouter: encoder=openai_encoder, routes=routes, index=index, auto_sync="local" ) if index_cls is PineconeIndex: - await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated - assert route_layer.async_is_synced() + await asyncio.sleep(PINECONE_SLEEP * 2) # allow for index to be populated + assert await route_layer.async_is_synced() @pytest.mark.skipif( os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" @@ -898,6 +901,7 @@ class TestAsyncSemanticRouter: index=pinecone_index, auto_sync="merge-force-local", ) + await asyncio.sleep(PINECONE_SLEEP * 2) @async_retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) async def check_sync(): @@ -905,11 +909,11 @@ class TestAsyncSemanticRouter: assert await route_layer.async_is_synced() # now confirm utterances are correct local_utterances = await route_layer.index.aget_utterances( - include_metadata=True + include_metadata=False ) # we sort to ensure order is the same # TODO JB: there is a bug here where if we include_metadata=True it fails - local_utterances.sort(key=lambda x: x.to_str(include_metadata=True)) + local_utterances.sort(key=lambda x: x.to_str(include_metadata=False)) assert local_utterances == [ Utterance(route="Route 1", utterance="Hello"), Utterance(route="Route 1", utterance="Hi"), @@ -972,6 +976,7 @@ class TestAsyncSemanticRouter: index=pinecone_index, auto_sync="merge-force-remote", ) + await asyncio.sleep(PINECONE_SLEEP) @async_retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) async def check_sync(): @@ -1046,6 +1051,7 @@ class TestAsyncSemanticRouter: index=pinecone_index, auto_sync="merge", ) + await asyncio.sleep(PINECONE_SLEEP) @async_retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP) async def check_sync():