Skip to content
Snippets Groups Projects
Unverified Commit fbc32773 authored by Sourabh Desai's avatar Sourabh Desai Committed by GitHub
Browse files

fix small client bug for edge case where playground pipeline exists with same...

fix small client bug for edge case where playground pipeline exists with same name as the managed pipeline (#10994)

* fix small client bug for edge case where playground pipeline exists with same name as the managed pipeline

* add elif condition as sanity check

* use enum.value
parent d36077f1
No related branches found
No related tags found
No related merge requests found
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from llama_index_client import TextNodeWithScore from llama_index_client import TextNodeWithScore
from llama_index_client.resources.pipeline.client import OMIT from llama_index_client.resources.pipeline.client import OMIT, PipelineType
from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.constants import DEFAULT_PROJECT_NAME from llama_index.core.constants import DEFAULT_PROJECT_NAME
...@@ -58,13 +58,19 @@ class LlamaCloudRetriever(BaseRetriever): ...@@ -58,13 +58,19 @@ class LlamaCloudRetriever(BaseRetriever):
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve from the platform.""" """Retrieve from the platform."""
pipelines = self._client.pipeline.search_pipelines( pipelines = self._client.pipeline.search_pipelines(
project_name=self.project_name, pipeline_name=self.name project_name=self.project_name,
pipeline_name=self.name,
pipeline_type=PipelineType.MANAGED.value,
) )
if len(pipelines) != 1: if len(pipelines) == 0:
raise ValueError( raise ValueError(
f"Unknown index name {self.name}. Please confirm a " f"Unknown index name {self.name}. Please confirm a "
"managed index with this name exists." "managed index with this name exists."
) )
elif len(pipelines) > 1:
raise ValueError(
f"Multiple pipelines found with name {self.name} in project {self.project_name}"
)
pipeline = pipelines[0] pipeline = pipelines[0]
if pipeline.id is None: if pipeline.id is None:
...@@ -90,9 +96,19 @@ class LlamaCloudRetriever(BaseRetriever): ...@@ -90,9 +96,19 @@ class LlamaCloudRetriever(BaseRetriever):
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Asynchronously retrieve from the platform.""" """Asynchronously retrieve from the platform."""
pipelines = await self._aclient.pipeline.search_pipelines( pipelines = await self._aclient.pipeline.search_pipelines(
project_name=self.project_name, pipeline_name=self.name project_name=self.project_name,
pipeline_name=self.name,
pipeline_type=PipelineType.MANAGED.value,
) )
assert len(pipelines) == 1 if len(pipelines) == 0:
raise ValueError(
f"Unknown index name {self.name}. Please confirm a "
"managed index with this name exists."
)
elif len(pipelines) > 1:
raise ValueError(
f"Multiple pipelines found with name {self.name} in project {self.project_name}"
)
pipeline = pipelines[0] pipeline = pipelines[0]
if pipeline.id is None: if pipeline.id is None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment