From 8921cbff4e4c6e58324393d12d991eb121ff9326 Mon Sep 17 00:00:00 2001
From: Jerry Liu <jerryjliu98@gmail.com>
Date: Thu, 1 Feb 2024 10:30:37 -0800
Subject: [PATCH] Support a custom LLM for the condense question chat engine in RAG CLI (#10390)

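CondenseQuestionChatEngine.from_defaults now accepts an optional llm
argument, so callers no longer need to build a full ServiceContext just to
swap out the model used to condense follow-up questions; the fallback
ServiceContext uses a MockEmbedding, since condensing a question never
embeds text. RagCLI now reuses the embedding model found in the ingestion
pipeline's transformations when building its ServiceContext, and constructs
the query pipeline on demand before building the chat engine.

A minimal usage sketch of the new argument (the query engine and model
choice below are illustrative, not part of this patch):

    from llama_index.chat_engine import CondenseQuestionChatEngine
    from llama_index.llms import OpenAI

    # query_engine is assumed to exist already, e.g. index.as_query_engine()
    chat_engine = CondenseQuestionChatEngine.from_defaults(
        query_engine=query_engine,
        llm=OpenAI(model="gpt-3.5-turbo"),  # any LLM instance works here
    )
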
---
 llama_index/chat_engine/condense_question.py | 14 ++++++++++++--
 llama_index/command_line/rag.py              | 19 ++++++++++++++++++-
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py
index 27430eaf88..560718955d 100644
--- a/llama_index/chat_engine/condense_question.py
+++ b/llama_index/chat_engine/condense_question.py
@@ -14,9 +14,11 @@ from llama_index.core.llms.types import ChatMessage, MessageRole
 from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
 from llama_index.llm_predictor.base import LLMPredictorType
 from llama_index.llms.generic_utils import messages_to_history_str
+from llama_index.llms.llm import LLM
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
 from llama_index.service_context import ServiceContext
+from llama_index.token_counter.mock_embed_model import MockEmbedding
 from llama_index.tools import ToolOutput
 
 logger = logging.getLogger(__name__)
@@ -74,13 +76,21 @@ class CondenseQuestionChatEngine(BaseChatEngine):
         verbose: bool = False,
         system_prompt: Optional[str] = None,
         prefix_messages: Optional[List[ChatMessage]] = None,
+        llm: Optional[LLM] = None,
         **kwargs: Any,
     ) -> "CondenseQuestionChatEngine":
         """Initialize a CondenseQuestionChatEngine from default parameters."""
         condense_question_prompt = condense_question_prompt or DEFAULT_PROMPT
 
-        service_context = service_context or ServiceContext.from_defaults()
-        llm = service_context.llm
+        if llm is None:
+            service_context = service_context or ServiceContext.from_defaults(
+                embed_model=MockEmbedding(embed_dim=2)
+            )
+            llm = service_context.llm
+        else:
+            service_context = service_context or ServiceContext.from_defaults(
+                llm=llm, embed_model=MockEmbedding(embed_dim=2)
+            )
 
         chat_history = chat_history or []
         memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
diff --git a/llama_index/command_line/rag.py b/llama_index/command_line/rag.py
index 57afe777e1..f57c880d0c 100644
--- a/llama_index/command_line/rag.py
+++ b/llama_index/command_line/rag.py
@@ -14,6 +14,7 @@ from llama_index import (
 from llama_index.bridge.pydantic import BaseModel, Field, validator
 from llama_index.chat_engine import CondenseQuestionChatEngine
 from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
+from llama_index.embeddings.base import BaseEmbedding
 from llama_index.ingestion import IngestionPipeline
 from llama_index.llms import LLM, OpenAI
 from llama_index.query_engine import CustomQueryEngine
@@ -98,7 +99,18 @@ class RagCLI(BaseModel):
             fn=query_input, output_key="output", req_params={"query_str"}
         )
         llm = cast(LLM, values["llm"])
-        service_context = ServiceContext.from_defaults(llm=llm)
+
+        # get embed_model from transformations if possible
+        embed_model = None
+        if ingestion_pipeline.transformations is not None:
+            for transformation in ingestion_pipeline.transformations:
+                if isinstance(transformation, BaseEmbedding):
+                    embed_model = transformation
+                    break
+
+        service_context = ServiceContext.from_defaults(
+            llm=llm, embed_model=embed_model or "default"
+        )
         retriever = VectorStoreIndex.from_vector_store(
             ingestion_pipeline.vector_store, service_context=service_context
         ).as_retriever(similarity_top_k=8)
@@ -130,6 +142,11 @@ class RagCLI(BaseModel):
         if chat_engine is not None:
             return chat_engine
 
+        if values.get("query_pipeline", None) is None:
+            values["query_pipeline"] = cls.query_pipeline_from_ingestion_pipeline(
+                query_pipeline=None, values=values
+            )
+
         query_pipeline = cast(QueryPipeline, values["query_pipeline"])
         if query_pipeline is None:
             return None
-- 
GitLab