diff --git a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py
index 17938c1ed77a5bf6edc3c968f1e120812054f5ac..4248af17fc41be8f48068270709e2b541a7bac49 100644
--- a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py
+++ b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py
@@ -59,6 +59,7 @@ class LLMSynonymRetriever(BasePGRetriever):
         ] = DEFAULT_SYNONYM_EXPAND_TEMPLATE,
         max_keywords: int = 10,
         path_depth: int = 1,
+        limit: int = 30,
         output_parsing_fn: Optional[Callable] = None,
         llm: Optional[LLM] = None,
         **kwargs: Any,
@@ -70,6 +71,7 @@ class LLMSynonymRetriever(BasePGRetriever):
         self._output_parsing_fn = output_parsing_fn
         self._max_keywords = max_keywords
         self._path_depth = path_depth
+        self._limit = limit
         super().__init__(
             graph_store=graph_store,
             include_text=include_text,
@@ -86,27 +88,35 @@ class LLMSynonymRetriever(BasePGRetriever):
         # capitalize to normalize with ingestion
         return [x.strip().capitalize() for x in matches if x.strip()]
 
-    def _prepare_matches(self, matches: List[str]) -> List[NodeWithScore]:
+    def _prepare_matches(
+        self, matches: List[str], limit: Optional[int] = None
+    ) -> List[NodeWithScore]:
         kg_nodes = self._graph_store.get(ids=matches)
         triplets = self._graph_store.get_rel_map(
             kg_nodes,
             depth=self._path_depth,
+            limit=limit or self._limit,
             ignore_rels=[KG_SOURCE_REL],
         )
 
         return self._get_nodes_with_score(triplets)
 
-    async def _aprepare_matches(self, matches: List[str]) -> List[NodeWithScore]:
+    async def _aprepare_matches(
+        self, matches: List[str], limit: Optional[int] = None
+    ) -> List[NodeWithScore]:
         kg_nodes = await self._graph_store.aget(ids=matches)
         triplets = await self._graph_store.aget_rel_map(
             kg_nodes,
             depth=self._path_depth,
+            limit=limit or self._limit,
             ignore_rels=[KG_SOURCE_REL],
         )
 
         return self._get_nodes_with_score(triplets)
 
-    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+    def retrieve_from_graph(
+        self, query_bundle: QueryBundle, limit: Optional[int] = None
+    ) -> List[NodeWithScore]:
         response = self._llm.predict(
             self._synonym_prompt,
             query_str=query_bundle.query_str,
@@ -114,10 +124,10 @@ class LLMSynonymRetriever(BasePGRetriever):
         )
         matches = self._parse_llm_output(response)
 
-        return self._prepare_matches(matches)
+        return self._prepare_matches(matches, limit=limit or self._limit)
 
     async def aretrieve_from_graph(
-        self, query_bundle: QueryBundle
+        self, query_bundle: QueryBundle, limit: Optional[int] = None
     ) -> List[NodeWithScore]:
         response = await self._llm.apredict(
             self._synonym_prompt,
@@ -126,4 +136,4 @@ class LLMSynonymRetriever(BasePGRetriever):
         )
         matches = self._parse_llm_output(response)
 
-        return await self._aprepare_matches(matches)
+        return await self._aprepare_matches(matches, limit=limit or self._limit)
diff --git a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py
index effbcd7e5471e91a07ddd17d81f8bb3231eb8584..93611fd3b14bee9988349b8c6445beabbd743c51 100644
--- a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py
+++ b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py
@@ -49,6 +49,7 @@ class VectorContextRetriever(BasePGRetriever):
         vector_store: Optional[BasePydanticVectorStore] = None,
         similarity_top_k: int = 4,
         path_depth: int = 1,
+        limit: int = 30,
         similarity_score: Optional[float] = None,
         filters: Optional[MetadataFilters] = None,
         **kwargs: Any,
@@ -58,6 +59,7 @@ class VectorContextRetriever(BasePGRetriever):
         self._similarity_top_k = similarity_top_k
         self._vector_store = vector_store
         self._path_depth = path_depth
+        self._limit = limit
         self._similarity_score = similarity_score
         self._filters = filters
 
@@ -112,7 +114,9 @@ class VectorContextRetriever(BasePGRetriever):
             **self._retriever_kwargs,
         )
 
-    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+    def retrieve_from_graph(
+        self, query_bundle: QueryBundle, limit: Optional[int] = None
+    ) -> List[NodeWithScore]:
         vector_store_query = self._get_vector_store_query(query_bundle)
 
         triplets = []
@@ -126,7 +130,10 @@ class VectorContextRetriever(BasePGRetriever):
            kg_ids = [node.id for node in kg_nodes]
 
            triplets = self._graph_store.get_rel_map(
-                kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
+                kg_nodes,
+                depth=self._path_depth,
+                limit=limit or self._limit,
+                ignore_rels=[KG_SOURCE_REL],
            )
 
        elif self._vector_store is not None:
@@ -136,7 +143,10 @@ class VectorContextRetriever(BasePGRetriever):
                scores = query_result.similarities
                kg_nodes = self._graph_store.get(ids=kg_ids)
                triplets = self._graph_store.get_rel_map(
-                    kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
+                    kg_nodes,
+                    depth=self._path_depth,
+                    limit=limit or self._limit,
+                    ignore_rels=[KG_SOURCE_REL],
                )
 
            elif query_result.ids is not None and query_result.similarities is not None:
@@ -144,7 +154,10 @@ class VectorContextRetriever(BasePGRetriever):
                kg_ids = query_result.ids
                scores = query_result.similarities
                kg_nodes = self._graph_store.get(ids=kg_ids)
                triplets = self._graph_store.get_rel_map(
-                    kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
+                    kg_nodes,
+                    depth=self._path_depth,
+                    limit=limit or self._limit,
+                    ignore_rels=[KG_SOURCE_REL],
                )
 
        for triplet in triplets:
@@ -174,7 +187,7 @@ class VectorContextRetriever(BasePGRetriever):
         return self._get_nodes_with_score([x[0] for x in top_k], [x[1] for x in top_k])
 
     async def aretrieve_from_graph(
-        self, query_bundle: QueryBundle
+        self, query_bundle: QueryBundle, limit: Optional[int] = None
     ) -> List[NodeWithScore]:
         vector_store_query = await self._aget_vector_store_query(query_bundle)
 
@@ -189,7 +202,10 @@ class VectorContextRetriever(BasePGRetriever):
            kg_nodes, scores = result
            kg_ids = [node.id for node in kg_nodes]
            triplets = await self._graph_store.aget_rel_map(
-                kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
+                kg_nodes,
+                depth=self._path_depth,
+                limit=limit or self._limit,
+                ignore_rels=[KG_SOURCE_REL],
            )
 
        elif self._vector_store is not None:
@@ -199,7 +215,10 @@ class VectorContextRetriever(BasePGRetriever):
                kg_ids = self._get_kg_ids(query_result.nodes)
                scores = query_result.similarities
                kg_nodes = await self._graph_store.aget(ids=kg_ids)
                triplets = await self._graph_store.aget_rel_map(
-                    kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
+                    kg_nodes,
+                    depth=self._path_depth,
+                    limit=limit or self._limit,
+                    ignore_rels=[KG_SOURCE_REL],
                )
 
            elif query_result.ids is not None and query_result.similarities is not None:
@@ -207,7 +226,10 @@ class VectorContextRetriever(BasePGRetriever):
                kg_ids = query_result.ids
                scores = query_result.similarities
                kg_nodes = await self._graph_store.aget(ids=kg_ids)
                triplets = await self._graph_store.aget_rel_map(
-                    kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
+                    kg_nodes,
+                    depth=self._path_depth,
+                    limit=limit or self._limit,
+                    ignore_rels=[KG_SOURCE_REL],
                )
 
        for triplet in triplets:
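For reference, a minimal usage sketch of the new `limit` knob (default 30) added to both sub-retrievers. Only the `limit` constructor argument and the per-call `limit` override on `retrieve_from_graph` come from this diff; the pre-built `index` (a `PropertyGraphIndex`), variable names, and query text below are illustrative assumptions.

```python
# Sketch only: assumes `index` is an existing, populated PropertyGraphIndex
# with a configured LLM and embedding model.
from llama_index.core.indices.property_graph import (
    LLMSynonymRetriever,
    VectorContextRetriever,
)
from llama_index.core.schema import QueryBundle

synonym_retriever = LLMSynonymRetriever(
    index.property_graph_store,
    include_text=False,
    path_depth=1,
    limit=30,  # new: caps triplets fetched via get_rel_map / aget_rel_map
)

vector_retriever = VectorContextRetriever(
    index.property_graph_store,
    include_text=False,
    similarity_top_k=4,
    path_depth=1,
    limit=30,  # new: same cap for the vector-based sub-retriever
)

retriever = index.as_retriever(sub_retrievers=[synonym_retriever, vector_retriever])
nodes = retriever.retrieve("Who founded the company?")

# The constructor default can also be overridden per call:
nodes = synonym_retriever.retrieve_from_graph(
    QueryBundle(query_str="Who founded the company?"), limit=10
)
```

Note that the override uses `limit or self._limit`, so passing `limit=0` (or `None`) falls back to the value given at construction time.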