diff --git a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
index 6c6b62818a6f1cc741f9953ea002f72c511286cb..605905767de79f46969fd35838c243d8fa81e475 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
@@ -96,6 +96,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
     def from_defaults(
         cls,
         retriever: BaseRetriever,
+        llm: Optional[LLM] = None,
         service_context: Optional[ServiceContext] = None,
         chat_history: Optional[List[ChatMessage]] = None,
         memory: Optional[BaseMemory] = None,
@@ -108,7 +109,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
         **kwargs: Any,
     ) -> "CondensePlusContextChatEngine":
         """Initialize a CondensePlusContextChatEngine from default parameters."""
-        llm = llm_from_settings_or_context(Settings, service_context)
+        llm = llm or llm_from_settings_or_context(Settings, service_context)
 
         chat_history = chat_history or []
         memory = memory or ChatMemoryBuffer.from_defaults(
diff --git a/llama-index-core/llama_index/core/indices/base.py b/llama-index-core/llama_index/core/indices/base.py
index 0b90785c9dffc044f175dad428893039910c1c94..a657a988526edc7bf40fd356bcd871468ab55d7c 100644
--- a/llama-index-core/llama_index/core/indices/base.py
+++ b/llama-index-core/llama_index/core/indices/base.py
@@ -2,7 +2,7 @@
 import logging
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast
+from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar
 
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -400,16 +400,22 @@ class BaseIndex(Generic[IS], ABC):
         llm: Optional[LLMType] = None,
         **kwargs: Any,
     ) -> BaseChatEngine:
-        llm = (
-            resolve_llm(llm, callback_manager=self._callback_manager)
-            if llm
-            else Settings.llm
-        )
+        service_context = kwargs.get("service_context", self.service_context)
 
-        query_engine = self.as_query_engine(llm=llm, **kwargs)
+        if service_context is not None:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else service_context.llm
+            )
+        else:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else Settings.llm
+            )
 
-        if "service_context" not in kwargs:
-            kwargs["service_context"] = self.service_context
+        query_engine = self.as_query_engine(llm=llm, **kwargs)
 
         # resolve chat mode
         if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
@@ -418,14 +424,14 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.agent import AgentRunner
             from llama_index.core.tools.query_engine import QueryEngineTool
 
-            # get LLM
-            service_context = cast(ServiceContext, kwargs["service_context"])
-            llm = service_context.llm
-
             # convert query engine to tool
             query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
 
-            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
+            return AgentRunner.from_llm(
+                tools=[query_engine_tool],
+                llm=llm,
+                **kwargs,
+            )
 
         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -433,6 +439,7 @@
 
             return CondenseQuestionChatEngine.from_defaults(
                 query_engine=query_engine,
+                llm=llm,
                 **kwargs,
             )
         elif chat_mode == ChatMode.CONTEXT:
@@ -440,6 +447,7 @@
 
             return ContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
 
@@ -448,6 +456,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return CondensePlusContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
 
@@ -455,6 +464,7 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.chat_engine import SimpleChatEngine
 
             return SimpleChatEngine.from_defaults(
+                llm=llm,
                 **kwargs,
             )
         else:
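
Taken together, the two files change where the LLM is resolved: `as_chat_engine` now resolves it once (an explicit `llm` argument wins, then the `service_context` found in kwargs or on the index, then the global `Settings.llm`) and forwards the result to whichever chat engine it constructs, while `CondensePlusContextChatEngine.from_defaults` accepts that override instead of always re-resolving from settings. A minimal sketch of the post-patch behavior follows; the `OpenAI` model and the in-memory `Document` are illustrative assumptions, not part of the patch:

    # Sketch of the behavior this patch enables; model name and document
    # contents are assumptions for illustration only.
    from llama_index.core import Document, VectorStoreIndex
    from llama_index.llms.openai import OpenAI

    index = VectorStoreIndex.from_documents(
        [Document(text="LlamaIndex is a data framework for LLM applications.")]
    )

    # Before this change, CondensePlusContextChatEngine.from_defaults had no
    # llm parameter and re-resolved its LLM from Settings / ServiceContext,
    # so the llm passed below only reached the underlying query engine.
    # After the change it is forwarded to the chat engine itself.
    chat_engine = index.as_chat_engine(
        chat_mode="condense_plus_context",
        llm=OpenAI(model="gpt-4"),
    )
    print(chat_engine.chat("What is LlamaIndex?"))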