Skip to content
Snippets Groups Projects
Commit 0c881c8f authored by Yi Ding's avatar Yi Ding
Browse files

finished subquestion demo

parent 815a3416
No related branches found
No related tags found
No related merge requests found
Simple flow:
Get document list, in this case one document.
Split each document into nodes, in this case sentences or lines.
Embed each of the nodes and get vectors. Store them in memory for now.
Embed query.
Compare the query vector with the node vectors and keep the top n matches.
Put the top n nodes into the prompt.
Execute prompt, get result.
// from llama_index import SimpleDirectoryReader, VectorStoreIndex
// from llama_index.query_engine import SubQuestionQueryEngine
// from llama_index.tools import QueryEngineTool, ToolMetadata
// # load data
// pg_essay = SimpleDirectoryReader(
// input_dir="docs/examples/data/paul_graham/"
// ).load_data()
// # build index and query engine
// query_engine = VectorStoreIndex.from_documents(pg_essay).as_query_engine()
// # setup base query engine as tool
// query_engine_tools = [
// QueryEngineTool(
// query_engine=query_engine,
// metadata=ToolMetadata(
// name="pg_essay", description="Paul Graham essay on What I Worked On"
// ),
// )
// ]
// query_engine = SubQuestionQueryEngine.from_defaults(
// query_engine_tools=query_engine_tools
// )
// response = query_engine.query(
// "How was Paul Grahams life different before and after YC?"
// )
// print(response)
import { Document } from "@llamaindex/core/src/Node";
import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex";
import { SubQuestionQueryEngine } from "@llamaindex/core/src/QueryEngine";
import essay from "./essay";
/**
 * Demo: answer a complex question about a single essay by decomposing it
 * into sub-questions with SubQuestionQueryEngine (TS port of the Python
 * snippet commented above).
 */
(async () => {
  // Load the full essay as one Document; node splitting happens inside the index.
  const document = new Document({ text: essay });

  // Build an in-memory vector index over the document.
  const index = await VectorStoreIndex.fromDocuments([document]);

  // Expose the index's query engine as a named tool the sub-question
  // engine can route generated sub-questions to.
  const queryEngine = SubQuestionQueryEngine.fromDefaults({
    queryEngineTools: [
      {
        queryEngine: index.asQueryEngine(),
        metadata: {
          name: "pg_essay",
          description: "Paul Graham essay on What I Worked On",
        },
      },
    ],
  });

  const response = await queryEngine.aquery(
    "How was Paul Grahams life different before and after YC?"
  );

  console.log(response);
})().catch((err: unknown) => {
  // Without this, any failure above becomes an unhandled promise rejection.
  console.error(err);
});
import { SubQuestion } from "./QuestionGenerator"; import { SubQuestion } from "./QuestionGenerator";
interface BaseOutputParser { export interface BaseOutputParser<T> {
parse(output: string): any; parse(output: string): T;
format(output: string): string; format(output: string): string;
} }
interface StructuredOutput { export interface StructuredOutput<T> {
rawOutput: string; rawOutput: string;
parsedOutput: T;
} }
class OutputParserError extends Error { class OutputParserError extends Error {
...@@ -62,10 +63,15 @@ function parseJsonMarkdown(text: string) { ...@@ -62,10 +63,15 @@ function parseJsonMarkdown(text: string) {
} }
} }
class SubQuestionOutputParser implements BaseOutputParser { export class SubQuestionOutputParser
parse(output: string): SubQuestion[] { implements BaseOutputParser<StructuredOutput<SubQuestion[]>>
const subQuestions = JSON.parse(output); {
return subQuestions; parse(output: string): StructuredOutput<SubQuestion[]> {
const parsed = parseJsonMarkdown(output);
// TODO add zod validation
return { rawOutput: output, parsedOutput: parsed };
} }
format(output: string): string { format(output: string): string {
......
import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor"; import { BaseLLMPredictor, ChatGPTLLMPredictor } from "./LLMPredictor";
import { SimplePrompt, defaultSubQuestionPrompt } from "./Prompt"; import {
BaseOutputParser,
StructuredOutput,
SubQuestionOutputParser,
} from "./OutputParser";
import {
SimplePrompt,
buildToolsText,
defaultSubQuestionPrompt,
} from "./Prompt";
import { ToolMetadata } from "./Tool"; import { ToolMetadata } from "./Tool";
export interface SubQuestion { export interface SubQuestion {
...@@ -14,16 +23,27 @@ export interface BaseQuestionGenerator { ...@@ -14,16 +23,27 @@ export interface BaseQuestionGenerator {
export class LLMQuestionGenerator implements BaseQuestionGenerator { export class LLMQuestionGenerator implements BaseQuestionGenerator {
llmPredictor: BaseLLMPredictor; llmPredictor: BaseLLMPredictor;
prompt: SimplePrompt; prompt: SimplePrompt;
outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;
constructor(init?: Partial<LLMQuestionGenerator>) { constructor(init?: Partial<LLMQuestionGenerator>) {
this.llmPredictor = init?.llmPredictor ?? new ChatGPTLLMPredictor(); this.llmPredictor = init?.llmPredictor ?? new ChatGPTLLMPredictor();
this.prompt = init?.prompt ?? defaultSubQuestionPrompt; this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
} }
async agenerate( async agenerate(
tools: ToolMetadata[], tools: ToolMetadata[],
query: string query: string
): Promise<SubQuestion[]> { ): Promise<SubQuestion[]> {
throw new Error("Method not implemented."); const toolsStr = buildToolsText(tools);
const queryStr = query;
const prediction = await this.llmPredictor.apredict(this.prompt, {
toolsStr,
queryStr,
});
const structuredOutput = this.outputParser.parse(prediction);
return structuredOutput.parsedOutput;
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment