From bf6ad5f13939108f75c3042ec7ac7fcc8e4f8819 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Mon, 12 Feb 2024 19:31:05 -0600
Subject: [PATCH] fix as_chat_engine specifying the LLM (#10605)

---
 .../core/chat_engine/condense_plus_context.py |  3 +-
 .../llama_index/core/indices/base.py          | 38 ++++++++++++-------
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
index 6c6b62818a..605905767d 100644
--- a/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
+++ b/llama-index-core/llama_index/core/chat_engine/condense_plus_context.py
@@ -96,6 +96,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
     def from_defaults(
         cls,
         retriever: BaseRetriever,
+        llm: Optional[LLM] = None,
         service_context: Optional[ServiceContext] = None,
         chat_history: Optional[List[ChatMessage]] = None,
         memory: Optional[BaseMemory] = None,
@@ -108,7 +109,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
         **kwargs: Any,
     ) -> "CondensePlusContextChatEngine":
         """Initialize a CondensePlusContextChatEngine from default parameters."""
-        llm = llm_from_settings_or_context(Settings, service_context)
+        llm = llm or llm_from_settings_or_context(Settings, service_context)
 
         chat_history = chat_history or []
         memory = memory or ChatMemoryBuffer.from_defaults(
diff --git a/llama-index-core/llama_index/core/indices/base.py b/llama-index-core/llama_index/core/indices/base.py
index 0b90785c9d..a657a98852 100644
--- a/llama-index-core/llama_index/core/indices/base.py
+++ b/llama-index-core/llama_index/core/indices/base.py
@@ -2,7 +2,7 @@
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast
+from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar
 
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -400,16 +400,22 @@ class BaseIndex(Generic[IS], ABC):
         llm: Optional[LLMType] = None,
         **kwargs: Any,
     ) -> BaseChatEngine:
-        llm = (
-            resolve_llm(llm, callback_manager=self._callback_manager)
-            if llm
-            else Settings.llm
-        )
+        service_context = kwargs.get("service_context", self.service_context)
 
-        query_engine = self.as_query_engine(llm=llm, **kwargs)
+        if service_context is not None:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else service_context.llm
+            )
+        else:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else Settings.llm
+            )
 
-        if "service_context" not in kwargs:
-            kwargs["service_context"] = self.service_context
+        query_engine = self.as_query_engine(llm=llm, **kwargs)
 
         # resolve chat mode
         if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
@@ -418,14 +424,14 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.agent import AgentRunner
             from llama_index.core.tools.query_engine import QueryEngineTool
 
-            # get LLM
-            service_context = cast(ServiceContext, kwargs["service_context"])
-            llm = service_context.llm
-
             # convert query engine to tool
             query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
 
-            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
+            return AgentRunner.from_llm(
+                tools=[query_engine_tool],
+                llm=llm,
+                **kwargs,
+            )
 
         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -433,6 +439,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return CondenseQuestionChatEngine.from_defaults(
                 query_engine=query_engine,
+                llm=llm,
                 **kwargs,
             )
         elif chat_mode == ChatMode.CONTEXT:
@@ -440,6 +447,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return ContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
 
@@ -448,6 +456,7 @@ class BaseIndex(Generic[IS], ABC):
 
             return CondensePlusContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
 
@@ -455,6 +464,7 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.chat_engine import SimpleChatEngine
 
             return SimpleChatEngine.from_defaults(
+                llm=llm,
                 **kwargs,
             )
         else:
-- 
GitLab