From bc17d27c0180262fa3d4f07149640a8f38625b4b Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Wed, 20 Mar 2024 18:30:06 -0600
Subject: [PATCH] remove stray service context usage (#12133)

---
 .../core/chat_engine/condense_question.py     | 12 ++----------
 .../node_parser/relational/base_element.py    | 19 +++----------------
 .../core/query_engine/router_query_engine.py  |  9 ++++++---
 3 files changed, 11 insertions(+), 29 deletions(-)

diff --git a/llama-index-core/llama_index/core/chat_engine/condense_question.py b/llama-index-core/llama_index/core/chat_engine/condense_question.py
index c90e288a8e..b864db0246 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_question.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_question.py
@@ -12,7 +12,6 @@ from llama_index.core.chat_engine.types import (
     StreamingAgentChatResponse,
 )
 from llama_index.core.chat_engine.utils import response_gen_from_query_engine
-from llama_index.core.embeddings.mock_embed_model import MockEmbedding
 from llama_index.core.base.llms.generic_utils import messages_to_history_str
 from llama_index.core.llms.llm import LLM
 from llama_index.core.memory import BaseMemory, ChatMemoryBuffer
@@ -22,6 +21,7 @@ from llama_index.core.service_context_elements.llm_predictor import LLMPredictor
 from llama_index.core.settings import (
     Settings,
     callback_manager_from_settings_or_context,
+    llm_from_settings_or_context,
 )
 
 from llama_index.core.tools import ToolOutput
@@ -86,15 +86,7 @@ class CondenseQuestionChatEngine(BaseChatEngine):
         """Initialize a CondenseQuestionChatEngine from default parameters."""
         condense_question_prompt = condense_question_prompt or DEFAULT_PROMPT
 
-        if llm is None:
-            service_context = service_context or ServiceContext.from_defaults(
-                embed_model=MockEmbedding(embed_dim=2)
-            )
-            llm = service_context.llm
-        else:
-            service_context = service_context or ServiceContext.from_defaults(
-                llm=llm, embed_model=MockEmbedding(embed_dim=2)
-            )
+        llm = llm or llm_from_settings_or_context(Settings, service_context)
 
         chat_history = chat_history or []
         memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
index 12f72887d8..62a068cfd5 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
@@ -143,22 +143,9 @@ class BaseElementNodeParser(NodeParser):
     def extract_table_summaries(self, elements: List[Element]) -> None:
         """Go through elements, extract out summaries that are tables."""
         from llama_index.core.indices.list.base import SummaryIndex
-        from llama_index.core.service_context import ServiceContext
+        from llama_index.core.settings import Settings
 
-        if self.llm:
-            llm = self.llm
-        else:
-            try:
-                from llama_index.llms.openai import OpenAI  # pants: no-infer-dep
-            except ImportError as e:
-                raise ImportError(
-                    "`llama-index-llms-openai` package not found."
-                    " Please install with `pip install llama-index-llms-openai`."
-                )
-            llm = OpenAI()
-        llm = cast(LLM, llm)
-
-        service_context = ServiceContext.from_defaults(llm=llm, embed_model=None)
+        llm = self.llm or Settings.llm
 
         table_context_list = []
         for idx, element in tqdm(enumerate(elements)):
@@ -178,7 +165,7 @@ class BaseElementNodeParser(NodeParser):
 
         async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
             index = SummaryIndex.from_documents(
-                [Document(text=table_context)], service_context=service_context
+                [Document(text=table_context)],
             )
             query_engine = index.as_query_engine(llm=llm, output_cls=TableOutput)
             try:
diff --git a/llama-index-core/llama_index/core/query_engine/router_query_engine.py b/llama-index-core/llama_index/core/query_engine/router_query_engine.py
index 97a2fd1a80..9f01599015 100644
--- a/llama-index-core/llama_index/core/query_engine/router_query_engine.py
+++ b/llama-index-core/llama_index/core/query_engine/router_query_engine.py
@@ -332,17 +332,20 @@ class ToolRetrieverRouterQueryEngine(BaseQueryEngine):
     def __init__(
         self,
         retriever: ObjectRetriever[QueryEngineTool],
+        llm: Optional[LLM] = None,
         service_context: Optional[ServiceContext] = None,
         summarizer: Optional[TreeSummarize] = None,
     ) -> None:
-        self.service_context = service_context or ServiceContext.from_defaults()
+        llm = llm or llm_from_settings_or_context(Settings, service_context)
         self._summarizer = summarizer or TreeSummarize(
-            service_context=self.service_context,
+            llm=llm,
             summary_template=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
         )
         self._retriever = retriever
 
-        super().__init__(self.service_context.callback_manager)
+        super().__init__(
+            callback_manager_from_settings_or_context(Settings, service_context)
+        )
 
     def _get_prompt_modules(self) -> PromptMixinType:
         """Get prompt sub-modules."""
--
GitLab
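
For reference, a minimal sketch of the calling pattern this patch migrates to: instead of building a ServiceContext internally, the touched components now resolve their LLM from the global Settings object (or from a legacy service_context, if one is still passed). The OpenAI import and model name below are illustrative assumptions, not part of the patch; any llama_index LLM works the same way.

    # Usage sketch, assuming llama-index-core plus the separately installed
    # llama-index-llms-openai package (pip install llama-index-llms-openai).
    from llama_index.core.settings import Settings
    from llama_index.llms.openai import OpenAI  # illustrative choice of LLM

    # Configure the process-wide default LLM once. The components changed above
    # (CondenseQuestionChatEngine, BaseElementNodeParser,
    # ToolRetrieverRouterQueryEngine) now fall back to this value, via
    # Settings.llm or llm_from_settings_or_context(Settings, service_context),
    # whenever no llm argument is passed explicitly.
    Settings.llm = OpenAI(model="gpt-3.5-turbo")  # model name is an arbitrary example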