Skip to content
Snippets Groups Projects
Unverified Commit fbc32773 authored by Sourabh Desai's avatar Sourabh Desai Committed by GitHub
Browse files

fix small client bug for edge case where playground pipeline exists with same...

fix small client bug for edge case where playground pipeline exists with same name as the managed pipeline (#10994)

* fix small client bug for edge case where playground pipeline exists with same name as the managed pipeline

* add elif condition as sanity check

* use enum.value
parent d36077f1
No related branches found
No related tags found
No related merge requests found
from typing import Any, Dict, List, Optional
from llama_index_client import TextNodeWithScore
from llama_index_client.resources.pipeline.client import OMIT
from llama_index_client.resources.pipeline.client import OMIT, PipelineType
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.constants import DEFAULT_PROJECT_NAME
......@@ -58,13 +58,19 @@ class LlamaCloudRetriever(BaseRetriever):
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve from the platform."""
pipelines = self._client.pipeline.search_pipelines(
project_name=self.project_name, pipeline_name=self.name
project_name=self.project_name,
pipeline_name=self.name,
pipeline_type=PipelineType.MANAGED.value,
)
if len(pipelines) != 1:
if len(pipelines) == 0:
raise ValueError(
f"Unknown index name {self.name}. Please confirm a "
"managed index with this name exists."
)
elif len(pipelines) > 1:
raise ValueError(
f"Multiple pipelines found with name {self.name} in project {self.project_name}"
)
pipeline = pipelines[0]
if pipeline.id is None:
......@@ -90,9 +96,19 @@ class LlamaCloudRetriever(BaseRetriever):
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Asynchronously retrieve from the platform."""
pipelines = await self._aclient.pipeline.search_pipelines(
project_name=self.project_name, pipeline_name=self.name
project_name=self.project_name,
pipeline_name=self.name,
pipeline_type=PipelineType.MANAGED.value,
)
assert len(pipelines) == 1
if len(pipelines) == 0:
raise ValueError(
f"Unknown index name {self.name}. Please confirm a "
"managed index with this name exists."
)
elif len(pipelines) > 1:
raise ValueError(
f"Multiple pipelines found with name {self.name} in project {self.project_name}"
)
pipeline = pipelines[0]
if pipeline.id is None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment