diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py index 28fa4f1fcbb99dd29a01992fa0f7dcd7653a49ff..8390c84c49338db1d8774206bf6aa15b98958b76 100644 --- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py +++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional from llama_index_client import TextNodeWithScore -from llama_index_client.resources.pipeline.client import OMIT +from llama_index_client.resources.pipeline.client import OMIT, PipelineType from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.constants import DEFAULT_PROJECT_NAME @@ -58,13 +58,19 @@ class LlamaCloudRetriever(BaseRetriever): def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve from the platform.""" pipelines = self._client.pipeline.search_pipelines( - project_name=self.project_name, pipeline_name=self.name + project_name=self.project_name, + pipeline_name=self.name, + pipeline_type=PipelineType.MANAGED.value, ) - if len(pipelines) != 1: + if len(pipelines) == 0: raise ValueError( f"Unknown index name {self.name}. Please confirm a " "managed index with this name exists." ) + elif len(pipelines) > 1: + raise ValueError( + f"Multiple pipelines found with name {self.name} in project {self.project_name}" + ) pipeline = pipelines[0] if pipeline.id is None: @@ -90,9 +96,19 @@ class LlamaCloudRetriever(BaseRetriever): async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Asynchronously retrieve from the platform.""" pipelines = await self._aclient.pipeline.search_pipelines( - project_name=self.project_name, pipeline_name=self.name + project_name=self.project_name, + pipeline_name=self.name, + pipeline_type=PipelineType.MANAGED.value, ) - assert len(pipelines) == 1 + if len(pipelines) == 0: + raise ValueError( + f"Unknown index name {self.name}. Please confirm a " + "managed index with this name exists." + ) + elif len(pipelines) > 1: + raise ValueError( + f"Multiple pipelines found with name {self.name} in project {self.project_name}" + ) pipeline = pipelines[0] if pipeline.id is None: