Unverified commit 254d4f2b authored by Alex Feel, committed by GitHub

feat: add lazy LLM initialization in RankGPTRerank (#10648)

parent ddd86175
@@ -9,7 +9,6 @@ from llama_index.core.prompts.default_prompts import RANKGPT_RERANK_PROMPT
 from llama_index.core.prompts.mixin import PromptDictType
 from llama_index.core.schema import NodeWithScore, QueryBundle
 from llama_index.core.utils import print_text
-from llama_index.llms.openai import OpenAI

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -19,10 +18,7 @@ class RankGPTRerank(BaseNodePostprocessor):
     """RankGPT-based reranker."""

     top_n: int = Field(default=5, description="Top N nodes to return from reranking.")
-    llm: LLM = Field(
-        default_factory=lambda: OpenAI(model="gpt-3.5-turbo-16k"),
-        description="LLM to use for rankGPT",
-    )
+    llm: Optional[LLM] = None
     verbose: bool = Field(
         default=False, description="Whether to print intermediate steps."
     )
@@ -49,6 +45,18 @@ class RankGPTRerank(BaseNodePostprocessor):
     def class_name(cls) -> str:
         return "RankGPTRerank"

+    def _ensure_llm(self) -> None:
+        if not self.llm:
+            try:
+                from llama_index.llms.openai import OpenAI
+
+                self.llm = OpenAI(model="gpt-3.5-turbo-16k")
+            except ImportError:
+                raise RuntimeError(
+                    "OpenAI LLM is not available. Please install `llama-index-llms-openai` "
+                    "or provide an alternative LLM instance."
+                )
+
     def _postprocess_nodes(
         self,
         nodes: List[NodeWithScore],
@@ -129,6 +137,7 @@ class RankGPTRerank(BaseNodePostprocessor):
         return messages

     def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
+        self._ensure_llm()
         return self.llm.chat(messages)

     def _clean_response(self, response: str) -> str:
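With this change, llama_index.llms.openai is no longer imported at module load time: the default OpenAI(model="gpt-3.5-turbo-16k") is only constructed the first time run_llm is called, which keeps the OpenAI integration an optional dependency. A minimal usage sketch under that assumption follows; the Ollama alternative, its model name, and the sample nodes are illustrative assumptions and not part of this commit, and the import paths assume the llama-index-core package layout at the time of this change.

    from llama_index.core.postprocessor import RankGPTRerank
    from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

    # Rely on the lazy default: no OpenAI import happens at construction time.
    # `llama-index-llms-openai` (and an API key) must be available once the
    # reranker first calls the LLM, otherwise the RuntimeError above is raised.
    reranker = RankGPTRerank(top_n=3)

    # Or pass any LLM up front so the OpenAI fallback is never touched,
    # e.g. a local model via the (assumed) `llama-index-llms-ollama` package:
    # from llama_index.llms.ollama import Ollama
    # reranker = RankGPTRerank(top_n=3, llm=Ollama(model="llama3"))

    # Hypothetical documents, used only to show the reranking call.
    nodes = [NodeWithScore(node=TextNode(text=t)) for t in ("doc one", "doc two", "doc three")]
    reranked = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle(query_str="example query"))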