From ebf3bc19fd0385e4b1de07b7fa91d5329bd68ab3 Mon Sep 17 00:00:00 2001 From: Yi Ding <yi.s.ding@gmail.com> Date: Thu, 6 Jul 2023 20:59:10 -0700 Subject: [PATCH] ContextChatEngine v1 --- apps/simple/chatEngine.ts | 24 ++++++++++++ packages/core/src/ChatEngine.ts | 61 +++++++++++++++++++++++++++++++ packages/core/src/LLMPredictor.ts | 2 + packages/core/src/Prompt.ts | 9 +++++ packages/core/src/Retriever.ts | 2 +- 5 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 apps/simple/chatEngine.ts diff --git a/apps/simple/chatEngine.ts b/apps/simple/chatEngine.ts new file mode 100644 index 000000000..26ff083d5 --- /dev/null +++ b/apps/simple/chatEngine.ts @@ -0,0 +1,24 @@ +// @ts-ignore +import * as readline from "node:readline/promises"; +// @ts-ignore +import { stdin as input, stdout as output } from "node:process"; +import { Document } from "@llamaindex/core/src/Node"; +import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex"; +import { ContextChatEngine } from "@llamaindex/core/src/ChatEngine"; +import essay from "./essay"; + +async function main() { + const document = new Document({ text: essay }); + const index = await VectorStoreIndex.fromDocuments([document]); + const retriever = index.asRetriever(); + const chatEngine = new ContextChatEngine({ retriever }); + const rl = readline.createInterface({ input, output }); + + while (true) { + const query = await rl.question("Query: "); + const response = await chatEngine.achat(query); + console.log(response); + } +} + +main().catch(console.error); diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index e96dc42d7..b3aebadce 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -4,19 +4,24 @@ import { ChatOpenAI, LLMResult, } from "./LanguageModel"; +import { TextNode } from "./Node"; import { SimplePrompt, + contextSystemPrompt, defaultCondenseQuestionPrompt, messagesToHistoryStr, } from "./Prompt"; import { BaseQueryEngine } from "./QueryEngine"; import { Response } from "./Response"; +import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; interface ChatEngine { chatRepl(): void; achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>; + + reset(): void; } export class SimpleChatEngine implements ChatEngine { @@ -108,3 +113,59 @@ export class CondenseQuestionChatEngine implements ChatEngine { this.chatHistory = []; } } + +export class ContextChatEngine implements ChatEngine { + retriever: BaseRetriever; + chatModel: BaseChatModel; + chatHistory: BaseMessage[]; + + constructor(init: { + retriever: BaseRetriever; + chatModel?: BaseChatModel; + chatHistory?: BaseMessage[]; + }) { + this.retriever = init.retriever; + this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k"); + this.chatHistory = init?.chatHistory ?? []; + } + + chatRepl() { + throw new Error("Method not implemented."); + } + + async achat(message: string, chatHistory?: BaseMessage[] | undefined) { + chatHistory = chatHistory ?? this.chatHistory; + + const sourceNodesWithScore = await this.retriever.aretrieve(message); + + const systemMessage: BaseMessage = { + content: contextSystemPrompt({ + context: sourceNodesWithScore + .map((r) => (r.node as TextNode).text) + .join("\n\n"), + }), + type: "system", + }; + + chatHistory.push({ content: message, type: "human" }); + + const response = await this.chatModel.agenerate([ + systemMessage, + ...chatHistory, + ]); + const text = response.generations[0][0].text; + + chatHistory.push({ content: text, type: "ai" }); + + this.chatHistory = chatHistory; + + return new Response( + text, + sourceNodesWithScore.map((r) => r.node) + ); + } + + reset() { + this.chatHistory = []; + } +} diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts index d63e0bbaf..da9112311 100644 --- a/packages/core/src/LLMPredictor.ts +++ b/packages/core/src/LLMPredictor.ts @@ -1,6 +1,7 @@ import { ChatOpenAI } from "./LanguageModel"; import { SimplePrompt } from "./Prompt"; +// TODO change this to LLM class export interface BaseLLMPredictor { getLlmMetadata(): Promise<any>; apredict( @@ -10,6 +11,7 @@ export interface BaseLLMPredictor { // stream(prompt: string, options: any): Promise<any>; } +// TODO change this to LLM class export class ChatGPTLLMPredictor implements BaseLLMPredictor { llm: string; retryOnThrottling: boolean; diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 8d7ac9789..3cfb6fac4 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -308,3 +308,12 @@ export function messagesToHistoryStr(messages: BaseMessage[]) { return acc; }, ""); } + +export const contextSystemPrompt: SimplePrompt = (input) => { + const { context } = input; + + return `Context information is below. +--------------------- +${context} +---------------------`; +}; diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index dfb72e88f..61928ff2c 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -8,7 +8,7 @@ import { } from "./storage/vectorStore/types"; export interface BaseRetriever { - aretrieve(query: string): Promise<any>; + aretrieve(query: string): Promise<NodeWithScore[]>; } export class VectorIndexRetriever implements BaseRetriever { -- GitLab