From 815a3416f29ec72f85a80c31b84cd6ed38bcb79c Mon Sep 17 00:00:00 2001
From: Yi Ding <yi.s.ding@gmail.com>
Date: Tue, 4 Jul 2023 20:53:36 -0700
Subject: [PATCH] more work

Flesh out the sub-question query pipeline in packages/core:

- OutputParser.ts (new): add OutputParserError, parseJsonMarkdown() for
  extracting the JSON payload of a fenced json block, and
  SubQuestionOutputParser.
- Prompt.ts: add buildToolsText() and replace the empty
  defaultSubQuestionPrompt with a few-shot prompt built from an Uber/Lyft
  example.
- QueryEngine.ts: implement SubQuestionQueryEngine: a constructor taking
  questionGen, responseSynthesizer and queryEngineTools, a fromDefaults()
  factory, and aquery(), which runs every generated sub-question through
  the query engine of its tool via Promise.all and synthesizes the
  answers. Sub-questions whose query throws are dropped.
- QuestionGenerator.ts: export SubQuestion and declare it before its
  first use.
- ResponseSynthesizer.ts: accept an optional responseBuilder.
- ServiceContext.ts: make the options argument optional.
- Tool.ts: give ToolMetadata name/description fields and add BaseTool and
  QueryEngineTool.
---
Review notes (not part of the commit message):

The line structure of this patch was restored, and three one-line fixes were
applied. Each replaces exactly one added line, so the hunk headers and the
diffstat below are unchanged:

  * Prompt.ts: "${toolsStr}}" -> "${toolsStr}". The extra brace was emitted
    literally after the tools JSON in the prompt.
  * QueryEngine.ts: "${responseText}}" -> "${responseText}". Same stray
    brace in the synthesized node text.
  * OutputParser.ts: SubQuestionOutputParser.parse() now calls
    parseJsonMarkdown() instead of JSON.parse(). The example output in the
    prompt is a fenced json block, which JSON.parse() rejects, and
    parseJsonMarkdown() was otherwise unused.

The "index" blob ids are those of the original export and no longer match
the post-image of the three edited files.

 packages/core/src/OutputParser.ts        | 74 ++++++++++++++++++++++
 packages/core/src/Prompt.ts              | 75 ++++++++++++++++++++++-
 packages/core/src/QueryEngine.ts         | 78 +++++++++++++++++++++---
 packages/core/src/QuestionGenerator.ts   | 10 +--
 packages/core/src/ResponseSynthesizer.ts |  4 +-
 packages/core/src/ServiceContext.ts      | 10 +--
 packages/core/src/Tool.ts                | 15 ++++-
 7 files changed, 244 insertions(+), 22 deletions(-)
 create mode 100644 packages/core/src/OutputParser.ts

diff --git a/packages/core/src/OutputParser.ts b/packages/core/src/OutputParser.ts
new file mode 100644
index 000000000..b7ab9bf31
--- /dev/null
+++ b/packages/core/src/OutputParser.ts
@@ -0,0 +1,74 @@
+import { SubQuestion } from "./QuestionGenerator";
+
+interface BaseOutputParser {
+  parse(output: string): any;
+  format(output: string): string;
+}
+
+interface StructuredOutput {
+  rawOutput: string;
+}
+
+class OutputParserError extends Error {
+  cause: Error | undefined;
+  output: string | undefined;
+
+  constructor(
+    message: string,
+    options: { cause?: Error; output?: string } = {}
+  ) {
+    // @ts-ignore
+    super(message, options); // https://github.com/tc39/proposal-error-cause
+    this.name = "OutputParserError";
+
+    if (!this.cause) {
+      // Need to check for those environments that have implemented the proposal
+      this.cause = options.cause;
+    }
+    this.output = options.output;
+
+    // This line is to maintain proper stack trace in V8
+    // (https://v8.dev/docs/stack-trace-api)
+    if (Error.captureStackTrace) {
+      Error.captureStackTrace(this, OutputParserError);
+    }
+  }
+}
+
+function parseJsonMarkdown(text: string) {
+  text = text.trim();
+
+  const beginDelimiter = "```json";
+  const endDelimiter = "```";
+
+  const beginIndex = text.indexOf(beginDelimiter);
+  const endIndex = text.indexOf(
+    endDelimiter,
+    beginIndex + beginDelimiter.length
+  );
+  if (beginIndex === -1 || endIndex === -1) {
+    throw new OutputParserError("Not a json markdown", { output: text });
+  }
+
+  const jsonText = text.substring(beginIndex + beginDelimiter.length, endIndex);
+
+  try {
+    return JSON.parse(jsonText);
+  } catch (e) {
+    throw new OutputParserError("Not a valid json", {
+      cause: e as Error,
+      output: text,
+    });
+  }
+}
+
+class SubQuestionOutputParser implements BaseOutputParser {
+  parse(output: string): SubQuestion[] {
+    const subQuestions = parseJsonMarkdown(output);
+    return subQuestions;
+  }
+
+  format(output: string): string {
+    return output;
+  }
+}
diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts
index ecfddcdc7..2410f0d2e 100644
--- a/packages/core/src/Prompt.ts
+++ b/packages/core/src/Prompt.ts
@@ -1,3 +1,6 @@
+import { SubQuestion } from "./QuestionGenerator";
+import { ToolMetadata } from "./Tool";
+
 /**
  * A SimplePrompt is a function that takes a dictionary of inputs and returns a string.
  * NOTE this is a different interface compared to LlamaIndex Python
@@ -190,6 +193,76 @@ SUFFIX = """\
 
 DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
 */
+
+export function buildToolsText(tools: ToolMetadata[]) {
+  const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => {
+    acc[tool.name] = tool.description;
+    return acc;
+  }, {});
+
+  return JSON.stringify(toolsObj, null, 4);
+}
+
+const exampleTools: ToolMetadata[] = [
+  {
+    name: "uber_10k",
+    description: "Provides information about Uber financials for year 2021",
+  },
+  {
+    name: "lyft_10k",
+    description: "Provides information about Lyft financials for year 2021",
+  },
+];
+
+const exampleQueryStr = `Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021`;
+
+const exampleOutput: SubQuestion[] = [
+  {
+    subQuestion: "What is the revenue growth of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Uber",
+    toolName: "uber_10k",
+  },
+  {
+    subQuestion: "What is the revenue growth of Lyft",
+    toolName: "lyft_10k",
+  },
+  {
+    subQuestion: "What is the EBITDA of Lyft",
+    toolName: "lyft_10k",
+  },
+];
+
 export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
-  return "";
+  const { toolsStr, queryStr } = input;
+
+  return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
+
+# Example 1
+<Tools>
+\`\`\`json
+${buildToolsText(exampleTools)}
+\`\`\`
+
+<User Question>
+${exampleQueryStr}
+
+<Output>
+\`\`\`json
+${JSON.stringify(exampleOutput, null, 4)}
+\`\`\`
+
+# Example 2
+<Tools>
+\`\`\`json
+${toolsStr}
+\`\`\`
+
+<User Question>
+${queryStr}
+
+<Output>
+`;
 };
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index 2300043e0..bddb0ea99 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -1,10 +1,14 @@
+import { NodeWithScore, TextNode } from "./Node";
 import {
   BaseQuestionGenerator,
   LLMQuestionGenerator,
+  SubQuestion,
 } from "./QuestionGenerator";
 import { Response } from "./Response";
-import { ResponseSynthesizer } from "./ResponseSynthesizer";
+import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
 import { BaseRetriever } from "./Retriever";
+import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { QueryEngineTool, ToolMetadata } from "./Tool";
 
 export interface BaseQueryEngine {
   aquery(query: string): Promise<Response>;
@@ -27,16 +31,74 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
 
 export class SubQuestionQueryEngine implements BaseQueryEngine {
   responseSynthesizer: ResponseSynthesizer;
-  questionGenerator: BaseQuestionGenerator;
+  questionGen: BaseQuestionGenerator;
+  queryEngines: Record<string, BaseQueryEngine>;
+  metadatas: ToolMetadata[];
 
-  constructor(init?: Partial<SubQuestionQueryEngine>) {
+  constructor(init: {
+    questionGen: BaseQuestionGenerator;
+    responseSynthesizer: ResponseSynthesizer;
+    queryEngineTools: QueryEngineTool[];
+  }) {
+    this.questionGen = init.questionGen;
     this.responseSynthesizer =
-      init?.responseSynthesizer ?? new ResponseSynthesizer();
-    this.questionGenerator =
-      init?.questionGenerator ?? new LLMQuestionGenerator();
+      init.responseSynthesizer ?? new ResponseSynthesizer();
+    this.queryEngines = init.queryEngineTools.reduce<
+      Record<string, BaseQueryEngine>
+    >((acc, tool) => {
+      acc[tool.metadata.name] = tool.queryEngine;
+      return acc;
+    }, {});
+    this.metadatas = init.queryEngineTools.map((tool) => tool.metadata);
   }
 
-  aquery(query: string): Promise<Response> {
-    throw new Error("Method not implemented.");
+  static fromDefaults(init: {
+    queryEngineTools: QueryEngineTool[];
+    questionGen?: BaseQuestionGenerator;
+    responseSynthesizer?: ResponseSynthesizer;
+    serviceContext?: ServiceContext;
+  }) {
+    const serviceContext =
+      init.serviceContext ?? serviceContextFromDefaults({});
+
+    const questionGen = init.questionGen ?? new LLMQuestionGenerator();
+    const responseSynthesizer =
+      init.responseSynthesizer ??
+      new ResponseSynthesizer(new CompactAndRefine(serviceContext));
+
+    return new SubQuestionQueryEngine({
+      questionGen,
+      responseSynthesizer,
+      queryEngineTools: init.queryEngineTools,
+    });
+  }
+
+  async aquery(query: string): Promise<Response> {
+    const subQuestions = await this.questionGen.agenerate(
+      this.metadatas,
+      query
+    );
+    const subQNodes = await Promise.all(
+      subQuestions.map((subQ) => this.aquerySubQ(subQ))
+    );
+    const nodes = subQNodes
+      .filter((node) => node !== null)
+      .map((node) => node as NodeWithScore);
+    return this.responseSynthesizer.asynthesize(query, nodes);
+  }
+
+  private async aquerySubQ(subQ: SubQuestion): Promise<NodeWithScore | null> {
+    try {
+      const question = subQ.subQuestion;
+      const queryEngine = this.queryEngines[subQ.toolName];
+
+      const response = await queryEngine.aquery(question);
+      const responseText = response.response;
+      const nodeText = `Sub question: ${question}\nResponse: ${responseText}`;
+      const node = new TextNode({ text: nodeText });
+      return { node, score: 0 };
+    } catch (error) {
+      return null;
+    }
   }
 }
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index 35689160b..be3b90360 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -2,6 +2,11 @@ import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
 import { SimplePrompt, defaultSubQuestionPrompt } from "./Prompt";
 import { ToolMetadata } from "./Tool";
 
+export interface SubQuestion {
+  subQuestion: string;
+  toolName: string;
+}
+
 export interface BaseQuestionGenerator {
   agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
 }
@@ -22,8 +27,3 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
     throw new Error("Method not implemented.");
   }
 }
-
-interface SubQuestion {
-  subQuestion: string;
-  toolName: string;
-}
diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts
index b2f347edd..f00cbf93f 100644
--- a/packages/core/src/ResponseSynthesizer.ts
+++ b/packages/core/src/ResponseSynthesizer.ts
@@ -186,8 +186,8 @@ export function getResponseBuilder(): BaseResponseBuilder {
 export class ResponseSynthesizer {
   responseBuilder: BaseResponseBuilder;
 
-  constructor() {
-    this.responseBuilder = getResponseBuilder();
+  constructor(responseBuilder?: BaseResponseBuilder) {
+    this.responseBuilder = responseBuilder ?? getResponseBuilder();
   }
 
   async asynthesize(query: string, nodes: NodeWithScore[]) {
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 2a635740e..dd72c771f 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -24,12 +24,12 @@ export interface ServiceContextOptions {
   chunkOverlap?: number;
 }
 
-export function serviceContextFromDefaults(options: ServiceContextOptions) {
+export function serviceContextFromDefaults(options?: ServiceContextOptions) {
   const serviceContext: ServiceContext = {
-    llmPredictor: options.llmPredictor ?? new ChatGPTLLMPredictor(),
-    embedModel: options.embedModel ?? new OpenAIEmbedding(),
-    nodeParser: options.nodeParser ?? new SimpleNodeParser(),
-    promptHelper: options.promptHelper ?? new PromptHelper(),
+    llmPredictor: options?.llmPredictor ?? new ChatGPTLLMPredictor(),
+    embedModel: options?.embedModel ?? new OpenAIEmbedding(),
+    nodeParser: options?.nodeParser ?? new SimpleNodeParser(),
+    promptHelper: options?.promptHelper ?? new PromptHelper(),
   };
 
   return serviceContext;
diff --git a/packages/core/src/Tool.ts b/packages/core/src/Tool.ts
index 221141e12..5eaecb125 100644
--- a/packages/core/src/Tool.ts
+++ b/packages/core/src/Tool.ts
@@ -1 +1,14 @@
-export interface ToolMetadata {}
+import { BaseQueryEngine } from "./QueryEngine";
+
+export interface ToolMetadata {
+  description: string;
+  name: string;
+}
+
+export interface BaseTool {
+  metadata: ToolMetadata;
+}
+
+export interface QueryEngineTool extends BaseTool {
+  queryEngine: BaseQueryEngine;
+}
-- 
GitLab