From aa0f586330789fc732225d29875bd78787d753fe Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Mon, 20 May 2024 15:57:58 +0700 Subject: [PATCH] feat: allow adding system prompt to chat engine (#855) Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de> --- .../core/src/engines/chat/ContextChatEngine.ts | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts index 1e1bc3393..9659a1f2f 100644 --- a/packages/core/src/engines/chat/ContextChatEngine.ts +++ b/packages/core/src/engines/chat/ContextChatEngine.ts @@ -31,6 +31,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatModel: LLM; chatHistory: ChatHistory; contextGenerator: ContextGenerator; + systemPrompt?: string; constructor(init: { retriever: BaseRetriever; @@ -38,9 +39,9 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { chatHistory?: ChatMessage[]; contextSystemPrompt?: ContextSystemPrompt; nodePostprocessors?: BaseNodePostprocessor[]; + systemPrompt?: string; }) { super(); - this.chatModel = init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" }); this.chatHistory = getHistory(init?.chatHistory); @@ -49,6 +50,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { contextSystemPrompt: init?.contextSystemPrompt, nodePostprocessors: init?.nodePostprocessors, }); + this.systemPrompt = init.systemPrompt; } protected _getPromptModules(): Record<string, ContextGenerator> { @@ -71,7 +73,6 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { message, chatHistory, ); - if (stream) { const stream = await this.chatModel.chat({ messages: requestMessages.messages, @@ -113,9 +114,16 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine { }); const textOnly = extractText(message); const context = await this.contextGenerator.generate(textOnly); - const messages = await chatHistory.requestMessages( - context ? [context.message] : undefined, - ); + const systemMessage = this.prependSystemPrompt(context.message); + const messages = await chatHistory.requestMessages([systemMessage]); return { nodes: context.nodes, messages }; } + + private prependSystemPrompt(message: ChatMessage): ChatMessage { + if (!this.systemPrompt) return message; + return { + ...message, + content: this.systemPrompt.trim() + "\n" + message.content, + }; + } } -- GitLab