diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-rankgpt-rerank/llama_index/postprocessor/rankgpt_rerank/base.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-rankgpt-rerank/llama_index/postprocessor/rankgpt_rerank/base.py
index 8eabab92c697ab38be5570cefa479e8f70122b71..1c3ecb77b871a99d7a4c8ce8ea53591f27a7956d 100644
--- a/llama-index-integrations/postprocessor/llama-index-postprocessor-rankgpt-rerank/llama_index/postprocessor/rankgpt_rerank/base.py
+++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-rankgpt-rerank/llama_index/postprocessor/rankgpt_rerank/base.py
@@ -9,7 +9,6 @@ from llama_index.core.prompts.default_prompts import RANKGPT_RERANK_PROMPT
 from llama_index.core.prompts.mixin import PromptDictType
 from llama_index.core.schema import NodeWithScore, QueryBundle
 from llama_index.core.utils import print_text
-from llama_index.llms.openai import OpenAI
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -19,10 +18,7 @@ class RankGPTRerank(BaseNodePostprocessor):
     """RankGPT-based reranker."""
 
     top_n: int = Field(default=5, description="Top N nodes to return from reranking.")
-    llm: LLM = Field(
-        default_factory=lambda: OpenAI(model="gpt-3.5-turbo-16k"),
-        description="LLM to use for rankGPT",
-    )
+    llm: Optional[LLM] = None
     verbose: bool = Field(
         default=False, description="Whether to print intermediate steps."
     )
@@ -49,6 +45,18 @@ class RankGPTRerank(BaseNodePostprocessor):
     def class_name(cls) -> str:
         return "RankGPTRerank"
 
+    def _ensure_llm(self) -> None:
+        if not self.llm:
+            try:
+                from llama_index.llms.openai import OpenAI
+
+                self.llm = OpenAI(model="gpt-3.5-turbo-16k")
+            except ImportError:
+                raise RuntimeError(
+                    "OpenAI LLM is not available. Please install `llama-index-llms-openai` "
+                    "or provide an alternative LLM instance."
+                )
+
     def _postprocess_nodes(
         self,
         nodes: List[NodeWithScore],
@@ -129,6 +137,7 @@
         return messages
 
     def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
+        self._ensure_llm()
         return self.llm.chat(messages)
 
     def _clean_response(self, response: str) -> str:
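
Usage note: after this change, `RankGPTRerank` no longer imports the OpenAI integration at module load time; callers either pass an `llm` explicitly or get the lazy OpenAI fallback on first use. A minimal sketch of both call paths, assuming `llama-index-core` is installed; node texts, scores, and the query string are illustrative placeholders, and `my_llm` stands in for any `LLM` instance:

    from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
    from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank

    # Placeholder candidate nodes to rerank.
    nodes = [
        NodeWithScore(node=TextNode(text="Paris is the capital of France."), score=0.4),
        NodeWithScore(node=TextNode(text="Berlin is the capital of Germany."), score=0.6),
    ]

    # Path 1: pass an LLM explicitly; llama-index-llms-openai is never imported.
    # (`my_llm` is a hypothetical stand-in for any LLM instance.)
    # reranker = RankGPTRerank(llm=my_llm, top_n=1)

    # Path 2: omit `llm`; _ensure_llm() lazily constructs
    # OpenAI(model="gpt-3.5-turbo-16k") on the first run_llm() call and raises
    # RuntimeError if llama-index-llms-openai is not installed.
    reranker = RankGPTRerank(top_n=1, verbose=True)

    reranked = reranker.postprocess_nodes(
        nodes, query_bundle=QueryBundle(query_str="What is the capital of France?")
    )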