From 5f9f81378cc973f854484df4f1335c20cfe82c49 Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Mon, 26 Jun 2023 08:10:45 -0700
Subject: [PATCH] add tree summarize
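
Add a TreeSummarize response builder that packs text chunks in pairs,
summarizes each pack against the query, and recurses on the summaries
until a single response remains. To support it: apredict() accepts a
SimplePrompt plus an input dictionary, agenerate() returns a structured
LLMResult, prompt inputs default to the empty string, and a new
PromptHelper skeleton carries the default sizing fields.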

---
 packages/core/src/LLMPredictor.ts        | 28 ++++++---
 packages/core/src/LanguageModel.ts       |  8 +--
 packages/core/src/Prompt.ts              |  7 ++-
 packages/core/src/PromptHelper.ts        | 14 +++++
 packages/core/src/ResponseSynthesizer.ts | 78 +++++++++++++++++++++++-
 packages/core/src/constants.ts           |  1 +
 6 files changed, 119 insertions(+), 17 deletions(-)
 create mode 100644 packages/core/src/PromptHelper.ts

diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts
index 9df10ce6b..ac54bc1aa 100644
--- a/packages/core/src/LLMPredictor.ts
+++ b/packages/core/src/LLMPredictor.ts
@@ -1,8 +1,12 @@
 import { ChatOpenAI } from "./LanguageModel";
+import { SimplePrompt } from "./Prompt";
 
 export interface BaseLLMPredictor {
   getLlmMetadata(): Promise<any>;
-  apredict(prompt: string, options: any): Promise<string>;
+  apredict(
+    prompt: string | SimplePrompt,
+    input?: { [key: string]: string }
+  ): Promise<string>;
   // stream(prompt: string, options: any): Promise<any>;
 }
 
@@ -25,13 +29,21 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor {
     throw new Error("Not implemented yet");
   }
 
-  async apredict(prompt: string, options: any) {
-    return this.languageModel.agenerate([
-      {
-        content: prompt,
-        type: "human",
-      },
-    ]);
+  async apredict(
+    prompt: string | SimplePrompt,
+    input?: { [key: string]: string }
+  ): Promise<string> {
+    if (typeof prompt === "string") {
+      const result = await this.languageModel.agenerate([
+        {
+          content: prompt,
+          type: "human",
+        },
+      ]);
+      return result.generations[0][0].text;
+    } else {
+      return this.apredict(prompt(input ?? {}));
+    }
   }
 
   // async stream(prompt: string, options: any) {
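
Review note (not part of the patch): apredict() now takes either a raw
string or a SimplePrompt with an input dictionary. A minimal sketch of the
two call styles, assuming the default ChatGPTLLMPredictor constructor
wires up ChatOpenAI and that this runs in an async context:

    import { ChatGPTLLMPredictor } from "./LLMPredictor";
    import { defaultSummaryPrompt } from "./Prompt";

    const predictor = new ChatGPTLLMPredictor();

    // Raw string: sent as a single "human" message.
    const greeting = await predictor.apredict("Say hello.");

    // SimplePrompt + inputs: the template is rendered to a string first,
    // then apredict() recurses with the result.
    const summary = await predictor.apredict(defaultSummaryPrompt, {
      context: "LlamaIndex.TS is a data framework for LLM applications.",
    });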
diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LanguageModel.ts
index cf0d1307f..8862e2fd9 100644
--- a/packages/core/src/LanguageModel.ts
+++ b/packages/core/src/LanguageModel.ts
@@ -6,8 +6,6 @@ import {
   getOpenAISession,
 } from "./openai";
 
-interface LLMResult {}
-
 export interface BaseLanguageModel {}
 
 type MessageType = "human" | "ai" | "system" | "generic" | "function";
@@ -22,7 +20,7 @@ interface Generation {
   generationInfo?: { [key: string]: any };
 }
 
-interface LLMResult {
+export interface LLMResult {
   generations: Generation[][]; // Each input can have more than one generation
 }
 
@@ -62,7 +60,7 @@ export class ChatOpenAI extends BaseChatModel {
     }
   }
 
-  async agenerate(messages: BaseMessage[]) {
+  async agenerate(messages: BaseMessage[]): Promise<LLMResult> {
     const { data } = await this.session.openai.createChatCompletion({
       model: this.model,
       temperature: this.temperature,
@@ -75,6 +73,6 @@ export class ChatOpenAI extends BaseChatModel {
     });
 
     const content = data.choices[0].message?.content ?? "";
-    return content;
+    return { generations: [[{ text: content }]] };
   }
 }
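
Review note (not part of the patch): agenerate() previously returned the
bare content string; it now returns an LLMResult so multiple generations
per input fit the same shape. A sketch of unwrapping it, assuming chat is
a constructed ChatOpenAI:

    const result = await chat.agenerate([{ content: "Hi there", type: "human" }]);
    // generations is indexed [input][candidate]; this implementation
    // returns exactly one candidate for a single input.
    const text = result.generations[0][0].text;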
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index 2d90617b3..8d1db56cf 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -1,6 +1,7 @@
 /**
  * A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
  * NOTE this is a different interface compared to LlamaIndex Python
+ * NOTE 2: missing inputs default to the empty string, which makes it easy to calculate prompt sizes
  */
 export type SimplePrompt = (input: { [key: string]: string }) => string;
 
@@ -16,7 +17,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
 */
 
 export const defaultTextQaPrompt: SimplePrompt = (input) => {
-  const { context, query } = input;
+  const { context = "", query = "" } = input;
 
   return `Context information is below.
 ---------------------
@@ -41,7 +42,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = (
 */
 
 export const defaultSummaryPrompt: SimplePrompt = (input) => {
-  const { context } = input;
+  const { context = "" } = input;
 
   return `Write a summary of the following. Try to use only the information provided. Try to include as many key details as possible.
 
@@ -69,7 +70,7 @@ DEFAULT_REFINE_PROMPT_TMPL = (
 */
 
 export const defaultRefinePrompt: SimplePrompt = (input) => {
-  const { query, existingAnswer, context } = input;
+  const { query = "", existingAnswer = "", context = "" } = input;
 
   return `The original question is as follows: ${query}
 We have provided an existing answer: ${existingAnswer}
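
Review note (not part of the patch): defaulting every input to the empty
string means a template can be rendered with no inputs at all, which is
what makes prompt-size accounting cheap. A sketch:

    // Rendering with an empty input yields just the template scaffolding,
    // whose fixed overhead can then be measured (characters here; a
    // tokenizer would give a tighter bound).
    const scaffolding = defaultRefinePrompt({});
    const overheadChars = scaffolding.length;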
diff --git a/packages/core/src/PromptHelper.ts b/packages/core/src/PromptHelper.ts
new file mode 100644
index 000000000..eb797f0c4
--- /dev/null
+++ b/packages/core/src/PromptHelper.ts
@@ -0,0 +1,14 @@
+import {
+  DEFAULT_CONTEXT_WINDOW,
+  DEFAULT_NUM_OUTPUTS,
+  DEFAULT_CHUNK_OVERLAP_RATIO,
+} from "./constants";
+
+export class PromptHelper {
+  contextWindow = DEFAULT_CONTEXT_WINDOW;
+  numOutput = DEFAULT_NUM_OUTPUTS;
+  chunkOverlapRatio = DEFAULT_CHUNK_OVERLAP_RATIO;
+  chunkSizeLimit?: number;
+  tokenizer?: (text: string) => string[];
+  separator = " ";
+}
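
Review note (not part of the patch): PromptHelper has no methods yet. As a
sketch of how its fields might combine, a hypothetical budget calculation;
availableChunkSize is illustrative and not part of this patch:

    // Hypothetical: tokens left for a text chunk after reserving room for
    // the model's output and the prompt scaffolding itself.
    function availableChunkSize(h: PromptHelper, promptTokens: number): number {
      const available = h.contextWindow - h.numOutput - promptTokens;
      return h.chunkSizeLimit ? Math.min(available, h.chunkSizeLimit) : available;
    }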
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index b7e23f800..d0c1e2855 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -2,12 +2,13 @@ import { ChatGPTLLMPredictor } from "./LLMPredictor";
 import { NodeWithScore } from "./Node";
 import { SimplePrompt, defaultTextQaPrompt } from "./Prompt";
 import { Response } from "./Response";
+import { ServiceContext } from "./ServiceContext";
 
 interface BaseResponseBuilder {
   agetResponse(query: string, textChunks: string[]): Promise<string>;
 }
 
-export class SimpleResponseBuilder {
+export class SimpleResponseBuilder implements BaseResponseBuilder {
   llmPredictor: ChatGPTLLMPredictor;
   textQATemplate: SimplePrompt;
 
@@ -27,6 +28,81 @@ export class SimpleResponseBuilder {
   }
 }
 
+export class Refine implements BaseResponseBuilder {
+  async agetResponse(
+    query: string,
+    textChunks: string[],
+    prevResponse?: any
+  ): Promise<string> {
+    throw new Error("Not implemented yet");
+  }
+
+  private giveResponseSingle(queryStr: string, textChunk: string) {
+    throw new Error("Not implemented yet"); // TODO: use defaultTextQaPrompt
+  }
+
+  private refineResponseSingle(
+    response: string,
+    queryStr: string,
+    textChunk: string
+  ) {
+    throw new Error("Not implemented yet");
+  }
+}
+export class CompactAndRefine extends Refine {
+  async agetResponse(
+    query: string,
+    textChunks: string[],
+    prevResponse?: any
+  ): Promise<string> {
+    throw new Error("Not implemented yet");
+  }
+}
+
+export class TreeSummarize implements BaseResponseBuilder {
+  serviceContext: ServiceContext;
+
+  constructor(serviceContext: ServiceContext) {
+    this.serviceContext = serviceContext;
+  }
+
+  async agetResponse(query: string, textChunks: string[]): Promise<string> {
+    const summaryTemplate: SimplePrompt = (input) =>
+      defaultTextQaPrompt({ ...input, query: query });
+
+    if (!textChunks || textChunks.length === 0) {
+      throw new Error("Must have at least one text chunk");
+    }
+
+    // TODO repack more intelligently
+    // Combine text chunks in pairs into packedTextChunks
+    const packedTextChunks: string[] = [];
+    for (let i = 0; i < textChunks.length; i += 2) {
+      if (i + 1 < textChunks.length) {
+        packedTextChunks.push(textChunks[i] + "\n\n" + textChunks[i + 1]);
+      } else {
+        packedTextChunks.push(textChunks[i]);
+      }
+    }
+
+    if (packedTextChunks.length === 1) {
+      return this.serviceContext.llmPredictor.apredict(summaryTemplate, {
+        context: packedTextChunks[0],
+      });
+    } else {
+      const summaries = await Promise.all(
+        packedTextChunks.map((chunk) =>
+          this.serviceContext.llmPredictor.apredict(summaryTemplate, {
+            context: chunk,
+          })
+        )
+      );
+
+      return this.agetResponse(query, summaries);
+    }
+  }
+}
+
 export function getResponseBuilder(): BaseResponseBuilder {
   return new SimpleResponseBuilder();
 }
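
Review note (not part of the patch): TreeSummarize collapses the chunk
list by roughly half each round until one packed chunk remains. A sketch
with five chunks, assuming serviceContext and c1..c5 are defined:

    // Round 1: [c1, c2, c3, c4, c5] -> summarize [c1+c2, c3+c4, c5] => 3 summaries
    // Round 2: [s1, s2, s3]         -> summarize [s1+s2, s3]        => 2 summaries
    // Round 3: [t1, t2]             -> one packed chunk             => final answer
    const builder = new TreeSummarize(serviceContext);
    const answer = await builder.agetResponse("What is the document about?", [
      c1, c2, c3, c4, c5,
    ]);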
diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts
index def343a79..2c3a3c92d 100644
--- a/packages/core/src/constants.ts
+++ b/packages/core/src/constants.ts
@@ -3,6 +3,7 @@ export const DEFAULT_NUM_OUTPUTS = 256;
 
 export const DEFAULT_CHUNK_SIZE = 1024;
 export const DEFAULT_CHUNK_OVERLAP = 20;
+export const DEFAULT_CHUNK_OVERLAP_RATIO = 0.1;
 export const DEFAULT_SIMILARITY_TOP_K = 2;
 
 // NOTE: for text-embedding-ada-002
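
Review note (not part of the patch): the new ratio is relative where the
existing DEFAULT_CHUNK_OVERLAP is absolute. Illustrative arithmetic only;
how the ratio is applied is not defined in this patch:

    // At DEFAULT_CHUNK_SIZE = 1024, a 0.1 overlap ratio would mean roughly
    // 102 overlapping tokens between adjacent chunks, versus the fixed
    // DEFAULT_CHUNK_OVERLAP = 20 used today.
    const overlapTokens = Math.floor(1024 * 0.1); // 102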
-- 
GitLab