From ebf3bc19fd0385e4b1de07b7fa91d5329bd68ab3 Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Thu, 6 Jul 2023 20:59:10 -0700
Subject: [PATCH] ContextChatEngine v1

---
 apps/simple/chatEngine.ts         | 24 ++++++++++++
 packages/core/src/ChatEngine.ts   | 61 +++++++++++++++++++++++++++++++
 packages/core/src/LLMPredictor.ts |  2 +
 packages/core/src/Prompt.ts       |  9 +++++
 packages/core/src/Retriever.ts    |  2 +-
 5 files changed, 97 insertions(+), 1 deletion(-)
 create mode 100644 apps/simple/chatEngine.ts

diff --git a/apps/simple/chatEngine.ts b/apps/simple/chatEngine.ts
new file mode 100644
index 000000000..26ff083d5
--- /dev/null
+++ b/apps/simple/chatEngine.ts
@@ -0,0 +1,24 @@
+// @ts-ignore
+import * as readline from "node:readline/promises";
+// @ts-ignore
+import { stdin as input, stdout as output } from "node:process";
+import { Document } from "@llamaindex/core/src/Node";
+import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
+import { ContextChatEngine } from "@llamaindex/core/src/ChatEngine";
+import essay from "./essay";
+
+async function main() {
+  const document = new Document({ text: essay });
+  const index = await VectorStoreIndex.fromDocuments([document]);
+  const retriever = index.asRetriever();
+  const chatEngine = new ContextChatEngine({ retriever });
+  const rl = readline.createInterface({ input, output });
+
+  while (true) {
+    const query = await rl.question("Query: ");
+    const response = await chatEngine.achat(query);
+    console.log(response);
+  }
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index e96dc42d7..b3aebadce 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -4,19 +4,24 @@ import {
   ChatOpenAI,
   LLMResult,
 } from "./LanguageModel";
+import { TextNode } from "./Node";
 import {
   SimplePrompt,
+  contextSystemPrompt,
   defaultCondenseQuestionPrompt,
   messagesToHistoryStr,
 } from "./Prompt";
 import { BaseQueryEngine } from "./QueryEngine";
 import { Response } from "./Response";
+import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 
 interface ChatEngine {
   chatRepl(): void;
 
   achat(message: string, chatHistory?: BaseMessage[]): Promise<Response>;
+
+  reset(): void;
 }
 
 export class SimpleChatEngine implements ChatEngine {
@@ -108,3 +113,59 @@ export class CondenseQuestionChatEngine implements ChatEngine {
     this.chatHistory = [];
   }
 }
+
+export class ContextChatEngine implements ChatEngine {
+  retriever: BaseRetriever;
+  chatModel: BaseChatModel;
+  chatHistory: BaseMessage[];
+
+  constructor(init: {
+    retriever: BaseRetriever;
+    chatModel?: BaseChatModel;
+    chatHistory?: BaseMessage[];
+  }) {
+    this.retriever = init.retriever;
+    this.chatModel = init.chatModel ?? new ChatOpenAI("gpt-3.5-turbo-16k");
+    this.chatHistory = init.chatHistory ?? [];
+  }
+
+  chatRepl() {
+    throw new Error("Method not implemented.");
+  }
+
+  async achat(message: string, chatHistory?: BaseMessage[]): Promise<Response> {
+    chatHistory = chatHistory ?? this.chatHistory;
+
+    const sourceNodesWithScore = await this.retriever.aretrieve(message);
+
+    const systemMessage: BaseMessage = {
+      content: contextSystemPrompt({
+        context: sourceNodesWithScore
+          .map((r) => (r.node as TextNode).text)
+          .join("\n\n"),
+      }),
+      type: "system",
+    };
+
+    chatHistory.push({ content: message, type: "human" });
+
+    const response = await this.chatModel.agenerate([
+      systemMessage,
+      ...chatHistory,
+    ]);
+    const text = response.generations[0][0].text;
+
+    chatHistory.push({ content: text, type: "ai" });
+
+    this.chatHistory = chatHistory;
+
+    return new Response(
+      text,
+      sourceNodesWithScore.map((r) => r.node)
+    );
+  }
+
+  reset() {
+    this.chatHistory = [];
+  }
+}
diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index d63e0bbaf..da9112311 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -1,6 +1,7 @@
 import { ChatOpenAI } from "./LanguageModel";
 import { SimplePrompt } from "./Prompt";
 
+// TODO change this to LLM class
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
   apredict(
@@ -10,6 +11,7 @@ export interface BaseLLMPredictor {
   // stream(prompt: string, options: any): Promise<any>;
 }
 
+// TODO change this to LLM class
 export class ChatGPTLLMPredictor implements BaseLLMPredictor {
   llm: string;
   retryOnThrottling: boolean;
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index 8d7ac9789..3cfb6fac4 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -308,3 +308,12 @@ export function messagesToHistoryStr(messages: BaseMessage[]) {
     return acc;
   }, "");
 }
+
+export const contextSystemPrompt: SimplePrompt = (input) => {
+  const { context } = input;
+
+  return `Context information is below.
+---------------------
+${context}
+---------------------`;
+};
diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts
index dfb72e88f..61928ff2c 100644
--- a/packages/core/src/Retriever.ts
+++ b/packages/core/src/Retriever.ts
@@ -8,7 +8,7 @@ import {
 } from "./storage/vectorStore/types";
 
 export interface BaseRetriever {
-  aretrieve(query: string): Promise<any>;
+  aretrieve(query: string): Promise<NodeWithScore[]>;
 }
 
 export class VectorIndexRetriever implements BaseRetriever {
-- 
GitLab