From fbc32773dcdc92078441e36f086e730dc62c4114 Mon Sep 17 00:00:00 2001 From: Sourabh Desai <sourabhdesai@gmail.com> Date: Mon, 19 Feb 2024 15:38:49 -0800 Subject: [PATCH] fix small client bug for edge case where playground pipeline exists with same name as the managed pipeline (#10994) * fix small client bug for edge case where playground pipeline exists with same name as the managed pipeline * add elif condition as sanity check * use enum.value --- .../indices/managed/llama_cloud/retriever.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py index 28fa4f1fcb..8390c84c49 100644 --- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py +++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional from llama_index_client import TextNodeWithScore -from llama_index_client.resources.pipeline.client import OMIT +from llama_index_client.resources.pipeline.client import OMIT, PipelineType from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.constants import DEFAULT_PROJECT_NAME @@ -58,13 +58,19 @@ class LlamaCloudRetriever(BaseRetriever): def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Retrieve from the platform.""" pipelines = self._client.pipeline.search_pipelines( - project_name=self.project_name, pipeline_name=self.name + project_name=self.project_name, + pipeline_name=self.name, + pipeline_type=PipelineType.MANAGED.value, ) - if len(pipelines) != 1: + if len(pipelines) == 0: raise ValueError( f"Unknown index name {self.name}. Please confirm a " "managed index with this name exists." ) + elif len(pipelines) > 1: + raise ValueError( + f"Multiple pipelines found with name {self.name} in project {self.project_name}" + ) pipeline = pipelines[0] if pipeline.id is None: @@ -90,9 +96,19 @@ class LlamaCloudRetriever(BaseRetriever): async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: """Asynchronously retrieve from the platform.""" pipelines = await self._aclient.pipeline.search_pipelines( - project_name=self.project_name, pipeline_name=self.name + project_name=self.project_name, + pipeline_name=self.name, + pipeline_type=PipelineType.MANAGED.value, ) - assert len(pipelines) == 1 + if len(pipelines) == 0: + raise ValueError( + f"Unknown index name {self.name}. Please confirm a " + "managed index with this name exists." + ) + elif len(pipelines) > 1: + raise ValueError( + f"Multiple pipelines found with name {self.name} in project {self.project_name}" + ) pipeline = pipelines[0] if pipeline.id is None: -- GitLab