diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py
index 27430eaf889c46d62aa81568e0ba03929a43c4be..560718955d3690620d0af743e16d2119892a4317 100644
--- a/llama_index/chat_engine/condense_question.py
+++ b/llama_index/chat_engine/condense_question.py
@@ -14,9 +14,11 @@ from llama_index.core.llms.types import ChatMessage, MessageRole
 from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
 from llama_index.llm_predictor.base import LLMPredictorType
 from llama_index.llms.generic_utils import messages_to_history_str
+from llama_index.llms.llm import LLM
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
 from llama_index.service_context import ServiceContext
+from llama_index.token_counter.mock_embed_model import MockEmbedding
 from llama_index.tools import ToolOutput
 
 logger = logging.getLogger(__name__)
@@ -74,13 +76,21 @@ class CondenseQuestionChatEngine(BaseChatEngine):
         verbose: bool = False,
         system_prompt: Optional[str] = None,
         prefix_messages: Optional[List[ChatMessage]] = None,
+        llm: Optional[LLM] = None,
         **kwargs: Any,
     ) -> "CondenseQuestionChatEngine":
         """Initialize a CondenseQuestionChatEngine from default parameters."""
         condense_question_prompt = condense_question_prompt or DEFAULT_PROMPT
 
-        service_context = service_context or ServiceContext.from_defaults()
-        llm = service_context.llm
+        if llm is None:
+            service_context = service_context or ServiceContext.from_defaults(
+                embed_model=MockEmbedding(embed_dim=2)
+            )
+            llm = service_context.llm
+        else:
+            service_context = service_context or ServiceContext.from_defaults(
+                llm=llm, embed_model=MockEmbedding(embed_dim=2)
+            )
 
         chat_history = chat_history or []
         memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm)
diff --git a/llama_index/command_line/rag.py b/llama_index/command_line/rag.py
index 57afe777e12adb911cb3652fc14709e4f0baaf04..f57c880d0c26a6a2ca15341b3d44c2bcb2da011a 100644
--- a/llama_index/command_line/rag.py
+++ b/llama_index/command_line/rag.py
@@ -14,6 +14,7 @@ from llama_index import (
 from llama_index.bridge.pydantic import BaseModel, Field, validator
 from llama_index.chat_engine import CondenseQuestionChatEngine
 from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
+from llama_index.embeddings.base import BaseEmbedding
 from llama_index.ingestion import IngestionPipeline
 from llama_index.llms import LLM, OpenAI
 from llama_index.query_engine import CustomQueryEngine
@@ -98,7 +99,18 @@ class RagCLI(BaseModel):
             fn=query_input, output_key="output", req_params={"query_str"}
         )
         llm = cast(LLM, values["llm"])
-        service_context = ServiceContext.from_defaults(llm=llm)
+
+        # get embed_model from transformations if possible
+        embed_model = None
+        if ingestion_pipeline.transformations is not None:
+            for transformation in ingestion_pipeline.transformations:
+                if isinstance(transformation, BaseEmbedding):
+                    embed_model = transformation
+                    break
+
+        service_context = ServiceContext.from_defaults(
+            llm=llm, embed_model=embed_model or "default"
+        )
         retriever = VectorStoreIndex.from_vector_store(
             ingestion_pipeline.vector_store, service_context=service_context
         ).as_retriever(similarity_top_k=8)
@@ -130,6 +142,11 @@ class RagCLI(BaseModel):
         if chat_engine is not None:
             return chat_engine
 
+        if values.get("query_pipeline", None) is None:
+            values["query_pipeline"] = cls.query_pipeline_from_ingestion_pipeline(
+                query_pipeline=None, values=values
+            )
+
         query_pipeline = cast(QueryPipeline, values["query_pipeline"])
         if query_pipeline is None:
             return None
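
For reference, a minimal sketch of how the new `llm` keyword on `CondenseQuestionChatEngine.from_defaults` could be exercised once this change lands. The diff only confirms the `llm` parameter and that the internally built `ServiceContext` uses a `MockEmbedding`; the index construction and the OpenAI model name below are illustrative assumptions, not part of this change.

from llama_index import Document, VectorStoreIndex
from llama_index.chat_engine import CondenseQuestionChatEngine
from llama_index.llms import OpenAI

# Illustrative setup: a tiny index whose query engine the chat engine wraps.
index = VectorStoreIndex.from_documents([Document(text="hello world")])
query_engine = index.as_query_engine()

# New in this diff: passing `llm` directly, so from_defaults no longer needs a
# caller-supplied ServiceContext just to pick the LLM (it builds one internally
# with a MockEmbedding, since the chat engine itself never embeds anything).
chat_engine = CondenseQuestionChatEngine.from_defaults(
    query_engine=query_engine,
    llm=OpenAI(model="gpt-3.5-turbo"),  # assumed model name, for illustration
)
print(chat_engine.chat("What did I index?"))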