diff --git a/apps/simple/chatEngine.ts b/apps/simple/chatEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..26ff083d5ca060c0764428ced8b337aca00845ed --- /dev/null +++ b/apps/simple/chatEngine.ts @@ -0,0 +1,24 @@ +// @ts-ignore +import * as readline from "node:readline/promises"; +// @ts-ignore +import { stdin as input, stdout as output } from "node:process"; +import { Document } from "@llamaindex/core/src/Node"; +import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex"; +import { ContextChatEngine } from "@llamaindex/core/src/ChatEngine"; +import essay from "./essay"; + +async function main() { + const document = new Document({ text: essay }); + const index = await VectorStoreIndex.fromDocuments([document]); + const retriever = index.asRetriever(); + const chatEngine = new ContextChatEngine({ retriever }); + const rl = readline.createInterface({ input, output }); + + while (true) { + const query = await rl.question("Query: "); + const response = await chatEngine.achat(query); + console.log(response); + } +} + +main().catch(console.error); diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index e96dc42d74cbd56b4a0c32e8f50448d1e74e7199..b3aebadce620fe69608d1bbb38ceea891f703384 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -4,19 +4,24 @@ import { ChatOpenAI, LLMResult, } from "./LanguageModel"; +import { TextNode } from "./Node"; import { SimplePrompt, + contextSystemPrompt, defaultCondenseQuestionPrompt, messagesToHistoryStr, } from "./Prompt"; import { BaseQueryEngine } from "./QueryEngine"; import { Response } from "./Response"; +import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; interface ChatEngine { chatRepl(): void; achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>; + + reset(): void; } export class SimpleChatEngine implements ChatEngine { @@ -108,3 +113,59 @@ export class CondenseQuestionChatEngine implements ChatEngine { this.chatHistory = []; } } + +export class ContextChatEngine implements ChatEngine { + retriever: BaseRetriever; + chatModel: BaseChatModel; + chatHistory: BaseMessage[]; + + constructor(init: { + retriever: BaseRetriever; + chatModel?: BaseChatModel; + chatHistory?: BaseMessage[]; + }) { + this.retriever = init.retriever; + this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k"); + this.chatHistory = init?.chatHistory ?? []; + } + + chatRepl() { + throw new Error("Method not implemented."); + } + + async achat(message: string, chatHistory?: BaseMessage[] | undefined) { + chatHistory = chatHistory ?? this.chatHistory; + + const sourceNodesWithScore = await this.retriever.aretrieve(message); + + const systemMessage: BaseMessage = { + content: contextSystemPrompt({ + context: sourceNodesWithScore + .map((r) => (r.node as TextNode).text) + .join("\n\n"), + }), + type: "system", + }; + + chatHistory.push({ content: message, type: "human" }); + + const response = await this.chatModel.agenerate([ + systemMessage, + ...chatHistory, + ]); + const text = response.generations[0][0].text; + + chatHistory.push({ content: text, type: "ai" }); + + this.chatHistory = chatHistory; + + return new Response( + text, + sourceNodesWithScore.map((r) => r.node) + ); + } + + reset() { + this.chatHistory = []; + } +} diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts index d63e0bbaff8087814f998f02f261a1946deca0d9..da91123119108efc6ef38126a7f3de11e2aef861 100644 --- a/packages/core/src/LLMPredictor.ts +++ b/packages/core/src/LLMPredictor.ts @@ -1,6 +1,7 @@ import { ChatOpenAI } from "./LanguageModel"; import { SimplePrompt } from "./Prompt"; +// TODO change this to LLM class export interface BaseLLMPredictor { getLlmMetadata(): Promise<any>; apredict( @@ -10,6 +11,7 @@ export interface BaseLLMPredictor { // stream(prompt: string, options: any): Promise<any>; } +// TODO change this to LLM class export class ChatGPTLLMPredictor implements BaseLLMPredictor { llm: string; retryOnThrottling: boolean; diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 8d7ac97899177681e8a22f77973cbc0207145d6f..3cfb6fac4a04c032963863482804447427ab5aa4 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -308,3 +308,12 @@ export function messagesToHistoryStr(messages: BaseMessage[]) { return acc; }, ""); } + +export const contextSystemPrompt: SimplePrompt = (input) => { + const { context } = input; + + return `Context information is below. +--------------------- +${context} +---------------------`; +}; diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index dfb72e88f577c0cbc1ae00421ce28eb18c7e345d..61928ff2c3fecfda8e25c4ee7d6f62a864d3f4c8 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -8,7 +8,7 @@ import { } from "./storage/vectorStore/types"; export interface BaseRetriever { - aretrieve(query: string): Promise<any>; + aretrieve(query: string): Promise<NodeWithScore[]>; } export class VectorIndexRetriever implements BaseRetriever {