From bf6ad5f13939108f75c3042ec7ac7fcc8e4f8819 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Mon, 12 Feb 2024 19:31:05 -0600
Subject: [PATCH] fix as_chat_engine specifying the LLM (#10605)

---
 .../core/chat_engine/condense_plus_context.py |  3 +-
 .../llama_index/core/indices/base.py          | 38 ++++++++++++-------
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
index 6c6b62818a..605905767d 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
@@ -96,6 +96,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
     def from_defaults(
         cls,
         retriever: BaseRetriever,
+        llm: Optional[LLM] = None,
         service_context: Optional[ServiceContext] = None,
         chat_history: Optional[List[ChatMessage]] = None,
         memory: Optional[BaseMemory] = None,
@@ -108,7 +109,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
         **kwargs: Any,
     ) -> "CondensePlusContextChatEngine":
         """Initialize a CondensePlusContextChatEngine from default parameters."""
-        llm = llm_from_settings_or_context(Settings, service_context)
+        llm = llm or llm_from_settings_or_context(Settings, service_context)
 
         chat_history = chat_history or []
         memory = memory or ChatMemoryBuffer.from_defaults(
diff --git a/llama-index-core/llama_index/core/indices/base.py b/llama-index-core/llama_index/core/indices/base.py
index 0b90785c9d..a657a98852 100644
--- a/llama-index-core/llama_index/core/indices/base.py
+++ b/llama-index-core/llama_index/core/indices/base.py
@@ -2,7 +2,7 @@
 import logging
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast
+from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar
 
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -400,16 +400,22 @@ class BaseIndex(Generic[IS], ABC):
         llm: Optional[LLMType] = None,
         **kwargs: Any,
     ) -> BaseChatEngine:
-        llm = (
-            resolve_llm(llm, callback_manager=self._callback_manager)
-            if llm
-            else Settings.llm
-        )
+        service_context = kwargs.get("service_context", self.service_context)
 
-        query_engine = self.as_query_engine(llm=llm, **kwargs)
+        if service_context is not None:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else service_context.llm
+            )
+        else:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else Settings.llm
+            )
 
-        if "service_context" not in kwargs:
-            kwargs["service_context"] = self.service_context
+        query_engine = self.as_query_engine(llm=llm, **kwargs)
 
         # resolve chat mode
         if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
@@ -418,14 +424,14 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.agent import AgentRunner
             from llama_index.core.tools.query_engine import QueryEngineTool
 
-            # get LLM
-            service_context = cast(ServiceContext, kwargs["service_context"])
-            llm = service_context.llm
-
             # convert query engine to tool
             query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
 
-            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
+            return AgentRunner.from_llm(
+                tools=[query_engine_tool],
+                llm=llm,
+                **kwargs,
+            )
 
         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -433,6 +439,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return CondenseQuestionChatEngine.from_defaults(
                 query_engine=query_engine,
+                llm=llm,
                 **kwargs,
             )
         elif chat_mode == ChatMode.CONTEXT:
@@ -440,6 +447,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return ContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
 
@@ -448,6 +456,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return CondensePlusContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
 
@@ -455,6 +464,7 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.chat_engine import SimpleChatEngine
 
             return SimpleChatEngine.from_defaults(
+                llm=llm,
                 **kwargs,
            )
         else:
-- 
GitLab
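
For context, a minimal sketch of the caller-side behavior this patch fixes,
assuming llama-index-core 0.10.x. MockLLM and MockEmbedding are stand-ins so
the snippet runs without API keys, and the document content is illustrative:

    from llama_index.core import Document, MockEmbedding, VectorStoreIndex
    from llama_index.core.llms import MockLLM

    # Build a tiny index; MockEmbedding avoids needing a real embedding model.
    index = VectorStoreIndex.from_documents(
        [Document(text="hello world")],
        embed_model=MockEmbedding(embed_dim=8),
    )

    # Before this patch, an LLM passed here was resolved and then dropped by
    # several chat modes, which fell back to Settings.llm or the service
    # context's LLM. With the patch, the resolved llm is forwarded both to
    # as_query_engine and to each chat engine's from_defaults.
    chat_engine = index.as_chat_engine(
        chat_mode="condense_plus_context",
        llm=MockLLM(),  # now actually used by the chat engine
    )
    response = chat_engine.chat("What does the document say?")
    print(response)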