diff --git a/llama-index-core/llama_index/core/postprocessor/rankGPT_rerank.py b/llama-index-core/llama_index/core/postprocessor/rankGPT_rerank.py
index 9659f10416af933d2968d0e3844ff1945173d43f..67f12dd7db56821766f9df1842aa67aff49566b1 100644
--- a/llama-index-core/llama_index/core/postprocessor/rankGPT_rerank.py
+++ b/llama-index-core/llama_index/core/postprocessor/rankGPT_rerank.py
@@ -69,6 +69,40 @@ class RankGPTRerank(BaseNodePostprocessor):
         messages = self.create_permutation_instruction(item=items)
         permutation = self.run_llm(messages=messages)
+        return self._llm_result_to_nodes(permutation, nodes, items)
+
+    async def _apostprocess_nodes(
+        self,
+        nodes: List[NodeWithScore],
+        query_bundle: Optional[QueryBundle] = None,
+    ) -> List[NodeWithScore]:
+        items = {
+            "query": query_bundle.query_str,
+            "hits": [{"content": node.get_content()} for node in nodes],
+        }
+
+        messages = self.create_permutation_instruction(item=items)
+        permutation = await self.arun_llm(messages=messages)
+        return self._llm_result_to_nodes(permutation, nodes, items)
+
+    async def apostprocess_nodes(
+        self,
+        nodes: List[NodeWithScore],
+        query_bundle: Optional[QueryBundle] = None,
+        query_str: Optional[str] = None,
+    ) -> List[NodeWithScore]:
+        """Postprocess nodes asynchronously."""
+        if query_str is not None and query_bundle is not None:
+            raise ValueError("Cannot specify both query_str and query_bundle")
+        elif query_str is not None:
+            query_bundle = QueryBundle(query_str)
+        else:
+            pass
+        return await self._apostprocess_nodes(nodes, query_bundle)
+
+    def _llm_result_to_nodes(
+        self, permutation: ChatResponse, nodes: List[NodeWithScore], items: Dict
+    ) -> List[NodeWithScore]:
         if permutation.message is not None and permutation.message.content is not None:
             rerank_ranks = self._receive_permutation(
                 items, str(permutation.message.content)
             )
@@ -136,6 +170,9 @@ class RankGPTRerank(BaseNodePostprocessor):
     def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
         return self.llm.chat(messages)
 
+    async def arun_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
+        return await self.llm.achat(messages)
+
     def _clean_response(self, response: str) -> str:
         new_response = ""
         for c in response:
diff --git a/llama-index-core/tests/postprocessor/test_rankgpt_rerank.py b/llama-index-core/tests/postprocessor/test_rankgpt_rerank.py
new file mode 100644
index 0000000000000000000000000000000000000000..04ebe3060e79d4fe9ec9fe7e36833b4b3617b39c
--- /dev/null
+++ b/llama-index-core/tests/postprocessor/test_rankgpt_rerank.py
@@ -0,0 +1,66 @@
+from typing import Any
+from unittest.mock import patch
+import asyncio
+
+import pytest
+from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
+from llama_index.core.llms.mock import MockLLM
+from llama_index.core.postprocessor.rankGPT_rerank import RankGPTRerank
+from llama_index.core.schema import TextNode, NodeWithScore
+
+
+def mock_rankgpt_chat(self: Any, messages, **kwargs: Any) -> ChatResponse:
+    return ChatResponse(
+        message=ChatMessage(role=MessageRole.SYSTEM, content="[2] > [1] > [3]")
+    )
+
+
+async def mock_rankgpt_achat(self, messages, **kwargs: Any) -> ChatResponse:
+    # Mock api call
+    await asyncio.sleep(1)
+    return ChatResponse(
+        message=ChatMessage(role=MessageRole.SYSTEM, content="[2] > [1] > [3]")
+    )
+
+
+nodes = [
+    TextNode(text="Test"),
+    TextNode(text="Test2"),
+    TextNode(text="Test3"),
+]
+nodes_with_score = [NodeWithScore(node=n) for n in nodes]
+
+
+@patch.object(
+    MockLLM,
+    "chat",
+    mock_rankgpt_chat,
+)
+def test_rankgpt_rerank():
+    rankgpt_rerank = RankGPTRerank(
+        top_n=2,
+        llm=MockLLM(),
+    )
+    result = rankgpt_rerank.postprocess_nodes(nodes_with_score, query_str="Test query")
+    assert len(result) == 2
+    assert result[0].node.get_content() == "Test2"
+    assert result[1].node.get_content() == "Test"
+
+
+@patch.object(
+    MockLLM,
+    "achat",
+    mock_rankgpt_achat,
+)
+@pytest.mark.asyncio()
+async def test_rankgpt_rerank_async():
+    rankgpt_rerank = RankGPTRerank(
+        top_n=2,
+        llm=MockLLM(),
+    )
+    result = await rankgpt_rerank.apostprocess_nodes(
+        nodes_with_score, query_str="Test query"
+    )
+    assert len(result) == 2
+    assert result[0].node.get_content() == "Test2"
+    assert result[1].node.get_content() == "Test"
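
A minimal usage sketch of the async entry point added in this diff (not part of the patch itself); it assumes the optional llama-index-llms-openai package and the model name "gpt-4o-mini" purely for illustration, and any async-capable LLM would work the same way:

import asyncio

from llama_index.core.postprocessor.rankGPT_rerank import RankGPTRerank
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.llms.openai import OpenAI  # illustrative; any LLM implementing achat works


async def main() -> None:
    # Build a few candidate nodes to rerank (placeholder texts).
    candidates = [NodeWithScore(node=TextNode(text=t)) for t in ("alpha", "beta", "gamma")]
    reranker = RankGPTRerank(top_n=2, llm=OpenAI(model="gpt-4o-mini"))
    # apostprocess_nodes mirrors postprocess_nodes but awaits the LLM via achat.
    reranked = await reranker.apostprocess_nodes(candidates, query_str="example query")
    print([n.node.get_content() for n in reranked])


asyncio.run(main())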