Unverified commit bf6ad5f1, authored by Logan, committed by GitHub

fix as_chat_engine specifying the LLM (#10605)

parent ce167610
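In short: previously, BaseIndex.as_chat_engine resolved the llm argument but only forwarded it to the underlying query engine; the chat engines themselves were built from the service context (and the agent modes even overrode the resolved LLM with service_context.llm). After this commit, the LLM is resolved once up front, preferring the explicit llm, then service_context.llm, then Settings.llm, and is passed into every chat engine's from_defaults. CondensePlusContextChatEngine.from_defaults also gains an llm parameter so the value is not silently dropped there.

A minimal usage sketch of the fixed behavior (a sketch, assuming the llama-index-llms-openai integration is installed and an API key is configured; the document text and model name are illustrative):

    from llama_index.core import Document, VectorStoreIndex
    from llama_index.llms.openai import OpenAI

    index = VectorStoreIndex.from_documents([Document(text="The sky is blue.")])

    # Before this fix, the llm argument here could be ignored in favor of
    # Settings.llm / the service context; now it is threaded through to the
    # chat engine itself.
    chat_engine = index.as_chat_engine(
        chat_mode="condense_plus_context",
        llm=OpenAI(model="gpt-3.5-turbo"),
    )
    response = chat_engine.chat("What color is the sky?")
    print(response)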
@@ -96,6 +96,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
     def from_defaults(
         cls,
         retriever: BaseRetriever,
+        llm: Optional[LLM] = None,
         service_context: Optional[ServiceContext] = None,
         chat_history: Optional[List[ChatMessage]] = None,
         memory: Optional[BaseMemory] = None,
@@ -108,7 +109,7 @@ class CondensePlusContextChatEngine(BaseChatEngine):
         **kwargs: Any,
     ) -> "CondensePlusContextChatEngine":
         """Initialize a CondensePlusContextChatEngine from default parameters."""
-        llm = llm_from_settings_or_context(Settings, service_context)
+        llm = llm or llm_from_settings_or_context(Settings, service_context)
         chat_history = chat_history or []
         memory = memory or ChatMemoryBuffer.from_defaults(
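With the new llm parameter, a model passed directly to CondensePlusContextChatEngine.from_defaults takes precedence over the Settings/service-context lookup. A short sketch (the index, retriever, and model name are illustrative stand-ins; assumes the OpenAI integration is installed):

    from llama_index.core import Document, VectorStoreIndex
    from llama_index.core.chat_engine import CondensePlusContextChatEngine
    from llama_index.llms.openai import OpenAI

    index = VectorStoreIndex.from_documents([Document(text="hello world")])

    # llm is now honored; before the fix the engine always used
    # llm_from_settings_or_context(Settings, service_context).
    engine = CondensePlusContextChatEngine.from_defaults(
        retriever=index.as_retriever(),
        llm=OpenAI(model="gpt-4"),
    )

The remaining hunks are in BaseIndex.as_chat_engine, where the LLM is resolved and handed to each chat mode.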
@@ -2,7 +2,7 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast
+from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar

 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -400,16 +400,22 @@ class BaseIndex(Generic[IS], ABC):
         llm: Optional[LLMType] = None,
         **kwargs: Any,
     ) -> BaseChatEngine:
-        llm = (
-            resolve_llm(llm, callback_manager=self._callback_manager)
-            if llm
-            else Settings.llm
-        )
-
-        query_engine = self.as_query_engine(llm=llm, **kwargs)
-
-        if "service_context" not in kwargs:
-            kwargs["service_context"] = self.service_context
+        service_context = kwargs.get("service_context", self.service_context)
+
+        if service_context is not None:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else service_context.llm
+            )
+        else:
+            llm = (
+                resolve_llm(llm, callback_manager=self._callback_manager)
+                if llm
+                else Settings.llm
+            )
+
+        query_engine = self.as_query_engine(llm=llm, **kwargs)

         # resolve chat mode
         if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
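The net effect of the rewritten block is a three-step precedence. A condensed restatement (same names as the diff, not the literal code):

    # Hedged paraphrase of the new resolution order:
    service_context = kwargs.get("service_context", self.service_context)
    if llm is not None:
        llm = resolve_llm(llm, callback_manager=self._callback_manager)
    elif service_context is not None:
        llm = service_context.llm
    else:
        llm = Settings.llm

Note that the old code also injected self.service_context into kwargs when absent; the new code only reads the service context, so the index's own context is no longer forced onto the downstream chat engines.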
@@ -418,14 +424,14 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.agent import AgentRunner
             from llama_index.core.tools.query_engine import QueryEngineTool

-            # get LLM
-            service_context = cast(ServiceContext, kwargs["service_context"])
-            llm = service_context.llm
-
             # convert query engine to tool
             query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)

-            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
+            return AgentRunner.from_llm(
+                tools=[query_engine_tool],
+                llm=llm,
+                **kwargs,
+            )

         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -433,6 +439,7 @@ class BaseIndex(Generic[IS], ABC):
             return CondenseQuestionChatEngine.from_defaults(
                 query_engine=query_engine,
+                llm=llm,
                 **kwargs,
             )
         elif chat_mode == ChatMode.CONTEXT:
@@ -440,6 +447,7 @@ class BaseIndex(Generic[IS], ABC):
             return ContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
@@ -448,6 +456,7 @@ class BaseIndex(Generic[IS], ABC):
             return CondensePlusContextChatEngine.from_defaults(
                 retriever=self.as_retriever(**kwargs),
+                llm=llm,
                 **kwargs,
             )
@@ -455,6 +464,7 @@ class BaseIndex(Generic[IS], ABC):
             from llama_index.core.chat_engine import SimpleChatEngine

             return SimpleChatEngine.from_defaults(
+                llm=llm,
                 **kwargs,
             )
         else:
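For the agent-based modes (REACT, OPENAI, BEST), the old code discarded the resolved LLM and re-read it from kwargs["service_context"]; with the fix, AgentRunner.from_llm receives the same resolved llm as every other mode. Continuing from the index and OpenAI import in the first sketch above (model name illustrative):

    # Agent mode now honors the explicit llm as well.
    agent = index.as_chat_engine(chat_mode="best", llm=OpenAI(model="gpt-4"))
    print(agent.chat("Summarize the document."))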