From 815a3416f29ec72f85a80c31b84cd6ed38bcb79c Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Tue, 4 Jul 2023 20:53:36 -0700
Subject: [PATCH] more work

---
 packages/core/src/OutputParser.ts        | 74 ++++++++++++++++++++++
 packages/core/src/Prompt.ts              | 75 ++++++++++++++++++++++-
 packages/core/src/QueryEngine.ts         | 78 +++++++++++++++++++++---
 packages/core/src/QuestionGenerator.ts   | 10 +--
 packages/core/src/ResponseSynthesizer.ts |  4 +-
 packages/core/src/ServiceContext.ts      | 10 +--
 packages/core/src/Tool.ts                | 15 ++++-
 7 files changed, 244 insertions(+), 22 deletions(-)
 create mode 100644 packages/core/src/OutputParser.ts

diff --git a/packages/core/src/OutputParser.ts b/packages/core/src/OutputParser.ts
new file mode 100644
index 000000000..b7ab9bf31
--- /dev/null
+++ b/packages/core/src/OutputParser.ts
@@ -0,0 +1,74 @@
+import { SubQuestion } from "./QuestionGenerator";
+
+interface BaseOutputParser {
+  parse(output: string): any;
+  format(output: string): string;
+}
+
+interface StructuredOutput {
+  rawOutput: string;
+}
+
+class OutputParserError extends Error {
+  cause: Error | undefined;
+  output: string | undefined;
+
+  constructor(
+    message: string,
+    options: { cause?: Error; output?: string } = {}
+  ) {
+    // @ts-ignore
+    super(message, options); // https://github.com/tc39/proposal-error-cause
+    this.name = "OutputParserError";
+
+    if (!this.cause) {
+      // Need to check for those environments that have implemented the proposal
+      this.cause = options.cause;
+    }
+    this.output = options.output;
+
+    // This line is to maintain proper stack trace in V8
+    // (https://v8.dev/docs/stack-trace-api)
+    if (Error.captureStackTrace) {
+      Error.captureStackTrace(this, OutputParserError);
+    }
+  }
+}
+
+function parseJsonMarkdown(text: string) {
+  text = text.trim();
+
+  const beginDelimiter = "```json";
+  const endDelimiter = "```";
+
+  const beginIndex = text.indexOf(beginDelimiter);
+  const endIndex = text.indexOf(
+    endDelimiter,
+    beginIndex + beginDelimiter.length
+  );
+  if (beginIndex === -1 || endIndex === -1) {
+    throw new OutputParserError("Not a json markdown", { output: text });
+  }
+
+  const jsonText = text.substring(beginIndex + beginDelimiter.length, endIndex);
+
+  try {
+    return JSON.parse(jsonText);
+  } catch (e) {
+    throw new OutputParserError("Not a valid json", {
+      cause: e as Error,
+      output: text,
+    });
+  }
+}
+
+class SubQuestionOutputParser implements BaseOutputParser {
+  parse(output: string): SubQuestion[] {
+    const subQuestions = JSON.parse(output);
+    return subQuestions;
+  }
+
+  format(output: string): string {
+    return output;
+  }
+}
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index ecfddcdc7..2410f0d2e 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -1,3 +1,6 @@
+import { SubQuestion } from "./QuestionGenerator";
+import { ToolMetadata } from "./Tool";
+
 /**
  * A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
  * NOTE this is a different interface compared to LlamaIndex Python
@@ -190,6 +193,76 @@ SUFFIX = """\
 
 DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
 */
+
+export function buildToolsText(tools: ToolMetadata[]) {
+  const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => {
+    acc[tool.name] = tool.description;
+    return acc;
+  }, {});
+
+  return JSON.stringify(toolsObj, null, 4);
+}
+
+const exampleTools: ToolMetadata[] = [
+  {
+    name: "uber_10k",
+    description: "Provides information about Uber financials for year 2021",
+  },
+  {
+    name: "lyft_10k",
+    description: "Provides information about Lyft financials for year 2021",
+  },
+];
+
+const exampleQueryStr = `Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021`;
+
+const exampleOutput: SubQuestion[] = [
+  {
+    subQuestion: "What is the revenue growth of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the revenue growth of Lyft",
+    toolName: "lyft_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Lyft",
+    toolName: "lyft_10k",
+  },
+];
+
 export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
-  return "";
+  const { toolsStr, queryStr } = input;
+
+  return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
+
+# Example 1
+<Tools>
+\`\`\`json
+${buildToolsText(exampleTools)}
+\`\`\`
+
+<User Question>
+${exampleQueryStr}
+
+<Output>
+\`\`\`json
+${JSON.stringify(exampleOutput, null, 4)}
+\`\`\`
+
+# Example 2
+<Tools>
+\`\`\`json
+${toolsStr}}
+\`\`\`
+
+<User Question>
+${queryStr}
+
+<Output>
+`;
 };
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index 2300043e0..bddb0ea99 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -1,10 +1,14 @@
+import { NodeWithScore, TextNode } from "./Node";
 import {
   BaseQuestionGenerator,
   LLMQuestionGenerator,
+  SubQuestion,
 } from "./QuestionGenerator";
 import { Response } from "./Response";
-import { ResponseSynthesizer } from "./ResponseSynthesizer";
+import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
 import { BaseRetriever } from "./Retriever";
+import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { QueryEngineTool, ToolMetadata } from "./Tool";
 
 export interface BaseQueryEngine {
   aquery(query: string): Promise<Response>;
@@ -27,16 +31,74 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
 
 export class SubQuestionQueryEngine implements BaseQueryEngine {
   responseSynthesizer: ResponseSynthesizer;
-  questionGenerator: BaseQuestionGenerator;
+  questionGen: BaseQuestionGenerator;
+  queryEngines: Record<string, BaseQueryEngine>;
+  metadatas: ToolMetadata[];
 
-  constructor(init?: Partial<SubQuestionQueryEngine>) {
+  constructor(init: {
+    questionGen: BaseQuestionGenerator;
+    responseSynthesizer: ResponseSynthesizer;
+    queryEngineTools: QueryEngineTool[];
+  }) {
+    this.questionGen = init.questionGen;
     this.responseSynthesizer =
-      init?.responseSynthesizer ?? new ResponseSynthesizer();
-    this.questionGenerator =
-      init?.questionGenerator ?? new LLMQuestionGenerator();
+      init.responseSynthesizer ?? new ResponseSynthesizer();
+    this.queryEngines = init.queryEngineTools.reduce<
+      Record<string, BaseQueryEngine>
+    >((acc, tool) => {
+      acc[tool.metadata.name] = tool.queryEngine;
+      return acc;
+    }, {});
+    this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
   }
 
-  aquery(query: string): Promise<Response> {
-    throw new Error("Method not implemented.");
+  static fromDefaults(init: {
+    queryEngineTools: QueryEngineTool[];
+    questionGen?: BaseQuestionGenerator;
+    responseSynthesizer?: ResponseSynthesizer;
+    serviceContext?: ServiceContext;
+  }) {
+    const serviceContext =
+      init.serviceContext ?? serviceContextFromDefaults({});
+
+    const questionGen = init.questionGen ?? new LLMQuestionGenerator();
+    const responseSynthesizer =
+      init.responseSynthesizer ??
+      new ResponseSynthesizer(new CompactAndRefine(serviceContext));
+
+    return new SubQuestionQueryEngine({
+      questionGen,
+      responseSynthesizer,
+      queryEngineTools: init.queryEngineTools,
+    });
+  }
+
+  async aquery(query: string): Promise<Response> {
+    const subQuestions = await this.questionGen.agenerate(
+      this.metadatas,
+      query
+    );
+    const subQNodes = await Promise.all(
+      subQuestions.map((subQ) => this.aquerySubQ(subQ))
+    );
+    const nodes = subQNodes
+      .filter((node) => node !== null)
+      .map((node) => node as NodeWithScore);
+    return this.responseSynthesizer.asynthesize(query, nodes);
+  }
+
+  private async aquerySubQ(subQ: SubQuestion): Promise<NodeWithScore | null> {
+    try {
+      const question = subQ.subQuestion;
+      const queryEngine = this.queryEngines[subQ.toolName];
+
+      const response = await queryEngine.aquery(question);
+      const responseText = response.response;
+      const nodeText = `Sub question: ${question}\nResponse: ${responseText}}`;
+      const node = new TextNode({ text: nodeText });
+      return { node, score: 0 };
+    } catch (error) {
+      return null;
+    }
   }
 }
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index 35689160b..be3b90360 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -2,6 +2,11 @@ import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
 import { SimplePrompt, defaultSubQuestionPrompt } from "./Prompt";
 import { ToolMetadata } from "./Tool";
 
+export interface SubQuestion {
+  subQuestion: string;
+  toolName: string;
+}
+
 export interface BaseQuestionGenerator {
   agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
 }
@@ -22,8 +27,3 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
     throw new Error("Method not implemented.");
   }
 }
-
-interface SubQuestion {
-  subQuestion: string;
-  toolName: string;
-}
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index b2f347edd..f00cbf93f 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -186,8 +186,8 @@ export function getResponseBuilder(): BaseResponseBuilder {
 export class ResponseSynthesizer {
   responseBuilder: BaseResponseBuilder;
 
-  constructor() {
-    this.responseBuilder = getResponseBuilder();
+  constructor(responseBuilder?: BaseResponseBuilder) {
+    this.responseBuilder = responseBuilder ?? getResponseBuilder();
   }
 
   async asynthesize(query: string, nodes: NodeWithScore[]) {
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 2a635740e..dd72c771f 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -24,12 +24,12 @@ export interface ServiceContextOptions {
   chunkOverlap?: number;
 }
 
-export function serviceContextFromDefaults(options: ServiceContextOptions) {
+export function serviceContextFromDefaults(options?: ServiceContextOptions) {
   const serviceContext: ServiceContext = {
-    llmPredictor: options.llmPredictor ?? new ChatGPTLLMPredictor(),
-    embedModel: options.embedModel ?? new OpenAIEmbedding(),
-    nodeParser: options.nodeParser ?? new SimpleNodeParser(),
-    promptHelper: options.promptHelper ?? new PromptHelper(),
+    llmPredictor: options?.llmPredictor ?? new ChatGPTLLMPredictor(),
+    embedModel: options?.embedModel ?? new OpenAIEmbedding(),
+    nodeParser: options?.nodeParser ?? new SimpleNodeParser(),
+    promptHelper: options?.promptHelper ?? new PromptHelper(),
   };
 
   return serviceContext;
diff --git a/packages/core/src/Tool.ts b/packages/core/src/Tool.ts
index 221141e12..5eaecb125 100644
--- a/packages/core/src/Tool.ts
+++ b/packages/core/src/Tool.ts
@@ -1 +1,14 @@
-export interface ToolMetadata {}
+import { BaseQueryEngine } from "./QueryEngine";
+
+export interface ToolMetadata {
+  description: string;
+  name: string;
+}
+
+export interface BaseTool {
+  metadata: ToolMetadata;
+}
+
+export interface QueryEngineTool extends BaseTool {
+  queryEngine: BaseQueryEngine;
+}
-- 
GitLab