Skip to content
Snippets Groups Projects
Unverified Commit d322a320 authored by kevin-yang-racap's avatar kevin-yang-racap Committed by GitHub
Browse files

Ky/dynamic pg triplet retrieval limit (#16928)

parent f60869ab
No related branches found
No related tags found
No related merge requests found
......@@ -59,6 +59,7 @@ class LLMSynonymRetriever(BasePGRetriever):
] = DEFAULT_SYNONYM_EXPAND_TEMPLATE,
max_keywords: int = 10,
path_depth: int = 1,
limit: int = 30,
output_parsing_fn: Optional[Callable] = None,
llm: Optional[LLM] = None,
**kwargs: Any,
......@@ -70,6 +71,7 @@ class LLMSynonymRetriever(BasePGRetriever):
self._output_parsing_fn = output_parsing_fn
self._max_keywords = max_keywords
self._path_depth = path_depth
self._limit = limit
super().__init__(
graph_store=graph_store,
include_text=include_text,
......@@ -86,27 +88,35 @@ class LLMSynonymRetriever(BasePGRetriever):
# capitalize to normalize with ingestion
return [x.strip().capitalize() for x in matches if x.strip()]
def _prepare_matches(
    self, matches: List[str], limit: Optional[int] = None
) -> List[NodeWithScore]:
    """Look up graph nodes by id and return their relation triplets as scored nodes.

    Args:
        matches: Ids passed to ``self._graph_store.get``.
        limit: Per-call value forwarded as ``limit`` to ``get_rel_map``.
            ``None`` (or any other falsy value, e.g. ``0``) falls back to
            the instance-level ``self._limit``.

    Returns:
        Whatever ``self._get_nodes_with_score`` produces for the triplets
        returned by the graph store.
    """
    kg_nodes = self._graph_store.get(ids=matches)
    triplets = self._graph_store.get_rel_map(
        kg_nodes,
        depth=self._path_depth,
        # `or` is deliberate here: a missing/falsy per-call limit uses the default.
        limit=limit or self._limit,
        ignore_rels=[KG_SOURCE_REL],
    )
    return self._get_nodes_with_score(triplets)
async def _aprepare_matches(
    self, matches: List[str], limit: Optional[int] = None
) -> List[NodeWithScore]:
    """Async lookup of graph nodes by id, returning their relation triplets as scored nodes.

    Args:
        matches: Ids passed to ``self._graph_store.aget``.
        limit: Per-call value forwarded as ``limit`` to ``aget_rel_map``.
            ``None`` (or any other falsy value, e.g. ``0``) falls back to
            the instance-level ``self._limit``.

    Returns:
        Whatever ``self._get_nodes_with_score`` produces for the triplets
        returned by the graph store.
    """
    kg_nodes = await self._graph_store.aget(ids=matches)
    triplets = await self._graph_store.aget_rel_map(
        kg_nodes,
        depth=self._path_depth,
        # `or` is deliberate here: a missing/falsy per-call limit uses the default.
        limit=limit or self._limit,
        ignore_rels=[KG_SOURCE_REL],
    )
    return self._get_nodes_with_score(triplets)
def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
def retrieve_from_graph(
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
response = self._llm.predict(
self._synonym_prompt,
query_str=query_bundle.query_str,
......@@ -114,10 +124,10 @@ class LLMSynonymRetriever(BasePGRetriever):
)
matches = self._parse_llm_output(response)
return self._prepare_matches(matches)
return self._prepare_matches(matches, limit=limit or self._limit)
async def aretrieve_from_graph(
self, query_bundle: QueryBundle
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
response = await self._llm.apredict(
self._synonym_prompt,
......@@ -126,4 +136,4 @@ class LLMSynonymRetriever(BasePGRetriever):
)
matches = self._parse_llm_output(response)
return await self._aprepare_matches(matches)
return await self._aprepare_matches(matches, limit=limit or self._limit)
......@@ -49,6 +49,7 @@ class VectorContextRetriever(BasePGRetriever):
vector_store: Optional[BasePydanticVectorStore] = None,
similarity_top_k: int = 4,
path_depth: int = 1,
limit: int = 30,
similarity_score: Optional[float] = None,
filters: Optional[MetadataFilters] = None,
**kwargs: Any,
......@@ -58,6 +59,7 @@ class VectorContextRetriever(BasePGRetriever):
self._similarity_top_k = similarity_top_k
self._vector_store = vector_store
self._path_depth = path_depth
self._limit = limit
self._similarity_score = similarity_score
self._filters = filters
......@@ -112,7 +114,9 @@ class VectorContextRetriever(BasePGRetriever):
**self._retriever_kwargs,
)
def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
def retrieve_from_graph(
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
vector_store_query = self._get_vector_store_query(query_bundle)
triplets = []
......@@ -126,7 +130,10 @@ class VectorContextRetriever(BasePGRetriever):
kg_ids = [node.id for node in kg_nodes]
triplets = self._graph_store.get_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)
elif self._vector_store is not None:
......@@ -136,7 +143,10 @@ class VectorContextRetriever(BasePGRetriever):
scores = query_result.similarities
kg_nodes = self._graph_store.get(ids=kg_ids)
triplets = self._graph_store.get_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)
elif query_result.ids is not None and query_result.similarities is not None:
......@@ -144,7 +154,10 @@ class VectorContextRetriever(BasePGRetriever):
scores = query_result.similarities
kg_nodes = self._graph_store.get(ids=kg_ids)
triplets = self._graph_store.get_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)
for triplet in triplets:
......@@ -174,7 +187,7 @@ class VectorContextRetriever(BasePGRetriever):
return self._get_nodes_with_score([x[0] for x in top_k], [x[1] for x in top_k])
async def aretrieve_from_graph(
self, query_bundle: QueryBundle
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
vector_store_query = await self._aget_vector_store_query(query_bundle)
......@@ -189,7 +202,10 @@ class VectorContextRetriever(BasePGRetriever):
kg_nodes, scores = result
kg_ids = [node.id for node in kg_nodes]
triplets = await self._graph_store.aget_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)
elif self._vector_store is not None:
......@@ -199,7 +215,10 @@ class VectorContextRetriever(BasePGRetriever):
scores = query_result.similarities
kg_nodes = await self._graph_store.aget(ids=kg_ids)
triplets = await self._graph_store.aget_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)
elif query_result.ids is not None and query_result.similarities is not None:
......@@ -207,7 +226,10 @@ class VectorContextRetriever(BasePGRetriever):
scores = query_result.similarities
kg_nodes = await self._graph_store.aget(ids=kg_ids)
triplets = await self._graph_store.aget_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)
for triplet in triplets:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment