Skip to content
Snippets Groups Projects
Unverified Commit 0b711e0a authored by Jeffrey (Dongkyu) Kim's avatar Jeffrey (Dongkyu) Kim Committed by GitHub
Browse files

Add async_postprocess_nodes at RankGPT Postprocessor Nodes (#12620)


* add async_postprocess_nodes and rankgpt test code

* fix typo and rename async functions

* use pytest.mark.asyncio

---------

Co-authored-by: default avatarjeffrey <vkefhdl1@gmail.com>
parent 5c59394c
No related branches found
No related tags found
No related merge requests found
...@@ -69,6 +69,40 @@ class RankGPTRerank(BaseNodePostprocessor): ...@@ -69,6 +69,40 @@ class RankGPTRerank(BaseNodePostprocessor):
messages = self.create_permutation_instruction(item=items) messages = self.create_permutation_instruction(item=items)
permutation = self.run_llm(messages=messages) permutation = self.run_llm(messages=messages)
return self._llm_result_to_nodes(permutation, nodes, items)
async def _apostprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
items = {
"query": query_bundle.query_str,
"hits": [{"content": node.get_content()} for node in nodes],
}
messages = self.create_permutation_instruction(item=items)
permutation = await self.arun_llm(messages=messages)
return self._llm_result_to_nodes(permutation, nodes, items)
async def apostprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
query_str: Optional[str] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes asynchronously."""
if query_str is not None and query_bundle is not None:
raise ValueError("Cannot specify both query_str and query_bundle")
elif query_str is not None:
query_bundle = QueryBundle(query_str)
else:
pass
return await self._apostprocess_nodes(nodes, query_bundle)
def _llm_result_to_nodes(
self, permutation: ChatResponse, nodes: List[NodeWithScore], items: Dict
) -> List[NodeWithScore]:
if permutation.message is not None and permutation.message.content is not None: if permutation.message is not None and permutation.message.content is not None:
rerank_ranks = self._receive_permutation( rerank_ranks = self._receive_permutation(
items, str(permutation.message.content) items, str(permutation.message.content)
...@@ -136,6 +170,9 @@ class RankGPTRerank(BaseNodePostprocessor): ...@@ -136,6 +170,9 @@ class RankGPTRerank(BaseNodePostprocessor):
def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse: def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
return self.llm.chat(messages) return self.llm.chat(messages)
async def arun_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
return await self.llm.achat(messages)
def _clean_response(self, response: str) -> str: def _clean_response(self, response: str) -> str:
new_response = "" new_response = ""
for c in response: for c in response:
......
from typing import Any
from unittest.mock import patch
import asyncio
import pytest
from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
from llama_index.core.llms.mock import MockLLM
from llama_index.core.postprocessor.rankGPT_rerank import RankGPTRerank
from llama_index.core.schema import TextNode, NodeWithScore
def mock_rankgpt_chat(self: Any, messages, **kwargs: Any) -> ChatResponse:
return ChatResponse(
message=ChatMessage(role=MessageRole.SYSTEM, content="[2] > [1] > [3]")
)
async def mock_rankgpt_achat(self, messages, **kwargs: Any) -> ChatResponse:
# Mock api call
await asyncio.sleep(1)
return ChatResponse(
message=ChatMessage(role=MessageRole.SYSTEM, content="[2] > [1] > [3]")
)
nodes = [
TextNode(text="Test"),
TextNode(text="Test2"),
TextNode(text="Test3"),
]
nodes_with_score = [NodeWithScore(node=n) for n in nodes]
@patch.object(
MockLLM,
"chat",
mock_rankgpt_chat,
)
def test_rankgpt_rerank():
rankgpt_rerank = RankGPTRerank(
top_n=2,
llm=MockLLM(),
)
result = rankgpt_rerank.postprocess_nodes(nodes_with_score, query_str="Test query")
assert len(result) == 2
assert result[0].node.get_content() == "Test2"
assert result[1].node.get_content() == "Test"
@patch.object(
MockLLM,
"achat",
mock_rankgpt_achat,
)
@pytest.mark.asyncio()
async def test_rankgpt_rerank_async():
rankgpt_rerank = RankGPTRerank(
top_n=2,
llm=MockLLM(),
)
result = await rankgpt_rerank.apostprocess_nodes(
nodes_with_score, query_str="Test query"
)
assert len(result) == 2
assert result[0].node.get_content() == "Test2"
assert result[1].node.get_content() == "Test"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment