diff --git a/apps/simple/simple.txt b/apps/simple/simple.txt deleted file mode 100644 index 7cd89b8d0d51646af616cc9f784f61fcb1469bf1..0000000000000000000000000000000000000000 --- a/apps/simple/simple.txt +++ /dev/null @@ -1,9 +0,0 @@ -Simple flow: - -Get document list, in this case one document. -Split each document into nodes, in this case sentences or lines. -Embed each of the nodes and get vectors. Store them in memory for now. -Embed query. -Compare query with nodes and get the top n -Put the top n nodes into the prompt. -Execute prompt, get result. diff --git a/apps/simple/subquestion.ts b/apps/simple/subquestion.ts new file mode 100644 index 0000000000000000000000000000000000000000..a3a85273dd4ef8e812ef92fdf1756607324d4a92 --- /dev/null +++ b/apps/simple/subquestion.ts @@ -0,0 +1,60 @@ +// from llama_index import SimpleDirectoryReader, VectorStoreIndex +// from llama_index.query_engine import SubQuestionQueryEngine +// from llama_index.tools import QueryEngineTool, ToolMetadata + +// # load data +// pg_essay = SimpleDirectoryReader( +// input_dir="docs/examples/data/paul_graham/" +// ).load_data() + +// # build index and query engine +// query_engine = VectorStoreIndex.from_documents(pg_essay).as_query_engine() + +// # setup base query engine as tool +// query_engine_tools = [ +// QueryEngineTool( +// query_engine=query_engine, +// metadata=ToolMetadata( +// name="pg_essay", description="Paul Graham essay on What I Worked On" +// ), +// ) +// ] + +// query_engine = SubQuestionQueryEngine.from_defaults( +// query_engine_tools=query_engine_tools +// ) + +// response = query_engine.query( +// "How was Paul Grahams life different before and after YC?" 
+// ) + +// print(response) + +import { Document } from "@llamaindex/core/src/Node"; +import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex"; +import { SubQuestionQueryEngine } from "@llamaindex/core/src/QueryEngine"; + +import essay from "./essay"; + +(async () => { + const document = new Document({ text: essay }); + const index = await VectorStoreIndex.fromDocuments([document]); + + const queryEngine = SubQuestionQueryEngine.fromDefaults({ + queryEngineTools: [ + { + queryEngine: index.asQueryEngine(), + metadata: { + name: "pg_essay", + description: "Paul Graham essay on What I Worked On", + }, + }, + ], + }); + + const response = await queryEngine.aquery( + "How was Paul Grahams life different before and after YC?" + ); + + console.log(response); +})(); diff --git a/packages/core/src/OutputParser.ts b/packages/core/src/OutputParser.ts new file mode 100644 index 0000000000000000000000000000000000000000..a5e0dded714507b64b8ff61de5d168673fd0fbd4 --- /dev/null +++ b/packages/core/src/OutputParser.ts @@ -0,0 +1,80 @@ +import { SubQuestion } from "./QuestionGenerator"; + +export interface BaseOutputParser<T> { + parse(output: string): T; + format(output: string): string; +} + +export interface StructuredOutput<T> { + rawOutput: string; + parsedOutput: T; +} + +class OutputParserError extends Error { + cause: Error | undefined; + output: string | undefined; + + constructor( + message: string, + options: { cause?: Error; output?: string } = {} + ) { + // @ts-ignore + super(message, options); // https://github.com/tc39/proposal-error-cause + this.name = "OutputParserError"; + + if (!this.cause) { + // Need to check for those environments that have implemented the proposal + this.cause = options.cause; + } + this.output = options.output; + + // This line is to maintain proper stack trace in V8 + // (https://v8.dev/docs/stack-trace-api) + if (Error.captureStackTrace) { + Error.captureStackTrace(this, OutputParserError); + } + } +} + +function 
parseJsonMarkdown(text: string) { + text = text.trim(); + + const beginDelimiter = "```json"; + const endDelimiter = "```"; + + const beginIndex = text.indexOf(beginDelimiter); + const endIndex = text.indexOf( + endDelimiter, + beginIndex + beginDelimiter.length + ); + if (beginIndex === -1 || endIndex === -1) { + throw new OutputParserError("Not a json markdown", { output: text }); + } + + const jsonText = text.substring(beginIndex + beginDelimiter.length, endIndex); + + try { + return JSON.parse(jsonText); + } catch (e) { + throw new OutputParserError("Not a valid json", { + cause: e as Error, + output: text, + }); + } +} + +export class SubQuestionOutputParser + implements BaseOutputParser<StructuredOutput<SubQuestion[]>> +{ + parse(output: string): StructuredOutput<SubQuestion[]> { + const parsed = parseJsonMarkdown(output); + + // TODO add zod validation + + return { rawOutput: output, parsedOutput: parsed }; + } + + format(output: string): string { + return output; + } +} diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 74d02c4bdbfac22fefd424dd78f1c889f4de820a..2410f0d2e800831e0f5fa20824421153b12f6ad4 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -1,3 +1,6 @@ +import { SubQuestion } from "./QuestionGenerator"; +import { ToolMetadata } from "./Tool"; + /** * A SimplePrompt is a function that takes a dictionary of inputs and returns a string. 
* NOTE this is a different interface compared to LlamaIndex Python @@ -114,3 +117,152 @@ ${context} Question: ${query} Answer:`; }; + +/* +PREFIX = """\ +Given a user question, and a list of tools, output a list of relevant sub-questions \ +that when composed can help answer the full user question: + +""" + + +example_query_str = ( + "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021" +) +example_tools = [ + ToolMetadata( + name="uber_10k", + description="Provides information about Uber financials for year 2021", + ), + ToolMetadata( + name="lyft_10k", + description="Provides information about Lyft financials for year 2021", + ), +] +example_tools_str = build_tools_text(example_tools) +example_output = [ + SubQuestion( + sub_question="What is the revenue growth of Uber", tool_name="uber_10k" + ), + SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"), + SubQuestion( + sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k" + ), + SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"), +] +example_output_str = json.dumps([x.dict() for x in example_output], indent=4) + +EXAMPLES = ( + """\ +# Example 1 +<Tools> +```json +{tools_str} +``` + +<User Question> +{query_str} + + +<Output> +```json +{output_str} +``` + +""".format( + query_str=example_query_str, + tools_str=example_tools_str, + output_str=example_output_str, + ) + .replace("{", "{{") + .replace("}", "}}") +) + +SUFFIX = """\ +# Example 2 +<Tools> +```json +{tools_str} +``` + +<User Question> +{query_str} + +<Output> +""" + +DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX +*/ + +export function buildToolsText(tools: ToolMetadata[]) { + const toolsObj = tools.reduce<Record<string, string>>((acc, tool) => { + acc[tool.name] = tool.description; + return acc; + }, {}); + + return JSON.stringify(toolsObj, null, 4); +} + +const exampleTools: ToolMetadata[] = [ + { + name: "uber_10k", + description: "Provides 
information about Uber financials for year 2021", + }, + { + name: "lyft_10k", + description: "Provides information about Lyft financials for year 2021", + }, +]; + +const exampleQueryStr = `Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021`; + +const exampleOutput: SubQuestion[] = [ + { + subQuestion: "What is the revenue growth of Uber", + toolName: "uber_10k", + }, + { + subQuestion: "What is the EBITDA of Uber", + toolName: "uber_10k", + }, + { + subQuestion: "What is the revenue growth of Lyft", + toolName: "lyft_10k", + }, + { + subQuestion: "What is the EBITDA of Lyft", + toolName: "lyft_10k", + }, +]; + +export const defaultSubQuestionPrompt: SimplePrompt = (input) => { + const { toolsStr, queryStr } = input; + + return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question: + +# Example 1 +<Tools> +\`\`\`json +${buildToolsText(exampleTools)} +\`\`\` + +<User Question> +${exampleQueryStr} + +<Output> +\`\`\`json +${JSON.stringify(exampleOutput, null, 4)} +\`\`\` + +# Example 2 +<Tools> +\`\`\`json +${toolsStr} +\`\`\` + +<User Question> +${queryStr} + +<Output> +`; +}; diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts index 10dace35954c1f89cb4e4662f1e607d5167d3a49..bddb0ea999cab222be0dc3c31df711ed09280e1a 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/QueryEngine.ts @@ -1,12 +1,20 @@ +import { NodeWithScore, TextNode } from "./Node"; +import { + BaseQuestionGenerator, + LLMQuestionGenerator, + SubQuestion, +} from "./QuestionGenerator"; import { Response } from "./Response"; -import { ResponseSynthesizer } from "./ResponseSynthesizer"; +import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer"; import { BaseRetriever } from "./Retriever"; +import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; +import { QueryEngineTool, ToolMetadata } from
"./Tool"; export interface BaseQueryEngine { aquery(query: string): Promise<Response>; } -export class RetrieverQueryEngine { +export class RetrieverQueryEngine implements BaseQueryEngine { retriever: BaseRetriever; responseSynthesizer: ResponseSynthesizer; @@ -20,3 +28,77 @@ export class RetrieverQueryEngine { return this.responseSynthesizer.asynthesize(query, nodes); } } + +export class SubQuestionQueryEngine implements BaseQueryEngine { + responseSynthesizer: ResponseSynthesizer; + questionGen: BaseQuestionGenerator; + queryEngines: Record<string, BaseQueryEngine>; + metadatas: ToolMetadata[]; + + constructor(init: { + questionGen: BaseQuestionGenerator; + responseSynthesizer: ResponseSynthesizer; + queryEngineTools: QueryEngineTool[]; + }) { + this.questionGen = init.questionGen; + this.responseSynthesizer = + init.responseSynthesizer ?? new ResponseSynthesizer(); + this.queryEngines = init.queryEngineTools.reduce< + Record<string, BaseQueryEngine> + >((acc, tool) => { + acc[tool.metadata.name] = tool.queryEngine; + return acc; + }, {}); + this.metadatas = init.queryEngineTools.map((tool) => tool.metadata); + } + + static fromDefaults(init: { + queryEngineTools: QueryEngineTool[]; + questionGen?: BaseQuestionGenerator; + responseSynthesizer?: ResponseSynthesizer; + serviceContext?: ServiceContext; + }) { + const serviceContext = + init.serviceContext ?? serviceContextFromDefaults({}); + + const questionGen = init.questionGen ?? new LLMQuestionGenerator(); + const responseSynthesizer = + init.responseSynthesizer ?? 
+ new ResponseSynthesizer(new CompactAndRefine(serviceContext)); + + return new SubQuestionQueryEngine({ + questionGen, + responseSynthesizer, + queryEngineTools: init.queryEngineTools, + }); + } + + async aquery(query: string): Promise<Response> { + const subQuestions = await this.questionGen.agenerate( + this.metadatas, + query + ); + const subQNodes = await Promise.all( + subQuestions.map((subQ) => this.aquerySubQ(subQ)) + ); + const nodes = subQNodes + .filter((node) => node !== null) + .map((node) => node as NodeWithScore); + return this.responseSynthesizer.asynthesize(query, nodes); + } + + private async aquerySubQ(subQ: SubQuestion): Promise<NodeWithScore | null> { + try { + const question = subQ.subQuestion; + const queryEngine = this.queryEngines[subQ.toolName]; + + const response = await queryEngine.aquery(question); + const responseText = response.response; + const nodeText = `Sub question: ${question}\nResponse: ${responseText}`; + const node = new TextNode({ text: nodeText }); + return { node, score: 0 }; + } catch (error) { + return null; + } + } +} diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts new file mode 100644 index 0000000000000000000000000000000000000000..fad1f9732c952095034e7b07ba22ea6e03b2ebf0 --- /dev/null +++ b/packages/core/src/QuestionGenerator.ts @@ -0,0 +1,49 @@ +import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor"; +import { + BaseOutputParser, + StructuredOutput, + SubQuestionOutputParser, +} from "./OutputParser"; +import { + SimplePrompt, + buildToolsText, + defaultSubQuestionPrompt, +} from "./Prompt"; +import { ToolMetadata } from "./Tool"; + +export interface SubQuestion { + subQuestion: string; + toolName: string; +} + +export interface BaseQuestionGenerator { + agenerate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>; +} + +export class LLMQuestionGenerator implements BaseQuestionGenerator { + llmPredictor: BaseLLMPredictor; + prompt: SimplePrompt;
+ outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>; + + constructor(init?: Partial<LLMQuestionGenerator>) { + this.llmPredictor = init?.llmPredictor ?? new ChatGPTLLMPredictor(); + this.prompt = init?.prompt ?? defaultSubQuestionPrompt; + this.outputParser = init?.outputParser ?? new SubQuestionOutputParser(); + } + + async agenerate( + tools: ToolMetadata[], + query: string + ): Promise<SubQuestion[]> { + const toolsStr = buildToolsText(tools); + const queryStr = query; + const prediction = await this.llmPredictor.apredict(this.prompt, { + toolsStr, + queryStr, + }); + + const structuredOutput = this.outputParser.parse(prediction); + + return structuredOutput.parsedOutput; + } +} diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts index b2f347eddabe7fed57355d249e33986dd61ec7f8..f00cbf93fb6f4f784e1a6b3b74aac6ec24ed3383 100644 --- a/packages/core/src/ResponseSynthesizer.ts +++ b/packages/core/src/ResponseSynthesizer.ts @@ -186,8 +186,8 @@ export function getResponseBuilder(): BaseResponseBuilder { export class ResponseSynthesizer { responseBuilder: BaseResponseBuilder; - constructor() { - this.responseBuilder = getResponseBuilder(); + constructor(responseBuilder?: BaseResponseBuilder) { + this.responseBuilder = responseBuilder ?? getResponseBuilder(); } async asynthesize(query: string, nodes: NodeWithScore[]) { diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts index 2a635740ebb2d9d7fe3e61df1fd453915e98bd54..dd72c771f2767715637a0cd29c53f3382ce89991 100644 --- a/packages/core/src/ServiceContext.ts +++ b/packages/core/src/ServiceContext.ts @@ -24,12 +24,12 @@ export interface ServiceContextOptions { chunkOverlap?: number; } -export function serviceContextFromDefaults(options: ServiceContextOptions) { +export function serviceContextFromDefaults(options?: ServiceContextOptions) { const serviceContext: ServiceContext = { - llmPredictor: options.llmPredictor ?? 
new ChatGPTLLMPredictor(), - embedModel: options.embedModel ?? new OpenAIEmbedding(), - nodeParser: options.nodeParser ?? new SimpleNodeParser(), - promptHelper: options.promptHelper ?? new PromptHelper(), + llmPredictor: options?.llmPredictor ?? new ChatGPTLLMPredictor(), + embedModel: options?.embedModel ?? new OpenAIEmbedding(), + nodeParser: options?.nodeParser ?? new SimpleNodeParser(), + promptHelper: options?.promptHelper ?? new PromptHelper(), }; return serviceContext; diff --git a/packages/core/src/Tool.ts b/packages/core/src/Tool.ts new file mode 100644 index 0000000000000000000000000000000000000000..5eaecb125cacd0e45848607229eadbdb9d4ed48b --- /dev/null +++ b/packages/core/src/Tool.ts @@ -0,0 +1,14 @@ +import { BaseQueryEngine } from "./QueryEngine"; + +export interface ToolMetadata { + description: string; + name: string; +} + +export interface BaseTool { + metadata: ToolMetadata; +} + +export interface QueryEngineTool extends BaseTool { + queryEngine: BaseQueryEngine; +}