diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts index 1e1bc33938f938db3eec4645db3a8bef86b757d0..9659a1f2fb41b0fbbbd5ee5806ff77199ab59823 100644 --- a/packages/core/src/engines/chat/ContextChatEngine.ts +++ b/packages/core/src/engines/chat/ContextChatEngine.ts @@ -31,6 +31,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatModel: LLM; chatHistory: ChatHistory; contextGenerator: ContextGenerator; + systemPrompt?: string; constructor(init: { retriever: BaseRetriever; @@ -38,9 +39,9 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatHistory?: ChatMessage[]; contextSystemPrompt?: ContextSystemPrompt; nodePostprocessors?: BaseNodePostprocessor[]; + systemPrompt?: string; }) { super(); - this.chatModel = init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" }); this.chatHistory = getHistory(init?.chatHistory); @@ -49,6 +50,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { contextSystemPrompt: init?.contextSystemPrompt, nodePostprocessors: init?.nodePostprocessors, }); + this.systemPrompt = init.systemPrompt; } protected _getPromptModules(): Record<string, ContextGenerator> { @@ -71,7 +73,6 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { message, chatHistory, ); - if (stream) { const stream = await this.chatModel.chat({ messages: requestMessages.messages, @@ -113,9 +114,16 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { }); const textOnly = extractText(message); const context = await this.contextGenerator.generate(textOnly); - const messages = await chatHistory.requestMessages( - context ? [context.message] : undefined, - ); + const systemMessage = this.prependSystemPrompt(context.message); + const messages = await chatHistory.requestMessages([systemMessage]); return { nodes: context.nodes, messages }; } + + private prependSystemPrompt(message: ChatMessage): ChatMessage { + if (!this.systemPrompt) return message; + return { + ...message, + content: this.systemPrompt.trim() + "\n" + message.content, + }; + } }