diff --git a/apps/docs/docs/modules/query_engines/_category_.yml b/apps/docs/docs/modules/query_engines/_category_.yml new file mode 100644 index 0000000000000000000000000000000000000000..fe92bdfe6fbe176663bbfd79109e179d067bc573 --- /dev/null +++ b/apps/docs/docs/modules/query_engines/_category_.yml @@ -0,0 +1,3 @@ +label: "Query Engines" +collapsed: false +position: 2 diff --git a/apps/docs/docs/modules/query_engine.md b/apps/docs/docs/modules/query_engines/index.md similarity index 98% rename from apps/docs/docs/modules/query_engine.md rename to apps/docs/docs/modules/query_engines/index.md index 65cc742f0ac13711a3580ae4f18af7e4af6eb878..83d6e875bf3b7c4253f854d090e045f2f87d8eb9 100644 --- a/apps/docs/docs/modules/query_engine.md +++ b/apps/docs/docs/modules/query_engines/index.md @@ -1,7 +1,3 @@ ---- -sidebar_position: 4 ---- - # QueryEngine A query engine wraps a `Retriever` and a `ResponseSynthesizer` into a pipeline, that will use the query string to fetech nodes and then send them to the LLM to generate a response. diff --git a/apps/docs/docs/modules/query_engines/router_query_engine.md b/apps/docs/docs/modules/query_engines/router_query_engine.md new file mode 100644 index 0000000000000000000000000000000000000000..c4d045c4077d827cc4732ff0b6096ac49035cc87 --- /dev/null +++ b/apps/docs/docs/modules/query_engines/router_query_engine.md @@ -0,0 +1,189 @@ +# Router Query Engine + +In this tutorial, we define a custom router query engine that selects one out of several candidate query engines to execute a query. + +## Setup + +First, we need to install and import the necessary modules from `llamaindex`: + +```bash +pnpm i llamaindex +``` + +```ts +import { + OpenAI, + RouterQueryEngine, + SimpleDirectoryReader, + SimpleNodeParser, + SummaryIndex, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; +``` + +## Loading Data + +Next, we need to load some data. 
We will use the `SimpleDirectoryReader` to load documents from a directory: + +```ts +const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples", +}); +``` + +## Service Context + +Next, we need to define some basic rules and parse the documents into nodes. We will use the `SimpleNodeParser` to parse the documents into nodes and `ServiceContext` to define the rules (eg. LLM API key, chunk size, etc.): + +```ts +const nodeParser = new SimpleNodeParser({ + chunkSize: 1024, +}); + +const serviceContext = serviceContextFromDefaults({ + nodeParser, + llm: new OpenAI(), +}); +``` + +## Creating Indices + +Next, we need to create some indices. We will create a `VectorStoreIndex` and a `SummaryIndex`: + +```ts +const vectorIndex = await VectorStoreIndex.fromDocuments(documents, { + serviceContext, +}); + +const summaryIndex = await SummaryIndex.fromDocuments(documents, { + serviceContext, +}); +``` + +## Creating Query Engines + +Next, we need to create some query engines. We will create a `VectorStoreQueryEngine` and a `SummaryQueryEngine`: + +```ts +const vectorQueryEngine = vectorIndex.asQueryEngine(); +const summaryQueryEngine = summaryIndex.asQueryEngine(); +``` + +## Creating a Router Query Engine + +Next, we need to create a router query engine. We will use the `RouterQueryEngine` to create a router query engine: + +We're defining two query engines, one for summarization and one for retrieving specific context. The router query engine will select the most appropriate query engine based on the query. 
+ +```ts +const queryEngine = RouterQueryEngine.fromDefaults({ + queryEngineTools: [ + { + queryEngine: vectorQueryEngine, + description: "Useful for retrieving specific context from Abramov", + }, + { + queryEngine: summaryQueryEngine, + description: "Useful for summarization questions related to Abramov", + }, + ], + serviceContext, +}); + ``` + +## Querying the Router Query Engine + +Finally, we can query the router query engine: + +```ts +const summaryResponse = await queryEngine.query({ + query: "Give me a summary about his past experiences?", +}); + +console.log({ + answer: summaryResponse.response, + metadata: summaryResponse?.metadata?.selectorResult, +}); +``` + +## Full code + +```ts +import { + OpenAI, + RouterQueryEngine, + SimpleDirectoryReader, + SimpleNodeParser, + SummaryIndex, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; + +async function main() { + // Load documents from a directory + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples", + }); + + // Parse the documents into nodes + const nodeParser = new SimpleNodeParser({ + chunkSize: 1024, + }); + + // Create a service context + const serviceContext = serviceContextFromDefaults({ + nodeParser, + llm: new OpenAI(), + }); + + // Create indices + const vectorIndex = await VectorStoreIndex.fromDocuments(documents, { + serviceContext, + }); + + const summaryIndex = await SummaryIndex.fromDocuments(documents, { + serviceContext, + }); + + // Create query engines + const vectorQueryEngine = vectorIndex.asQueryEngine(); + const summaryQueryEngine = summaryIndex.asQueryEngine(); + + // Create a router query engine + const queryEngine = RouterQueryEngine.fromDefaults({ + queryEngineTools: [ + { + queryEngine: vectorQueryEngine, + description: "Useful for retrieving specific context from Abramov", + }, + { + queryEngine: summaryQueryEngine, + description: "Useful for summarization questions related to Abramov", + },
+ ], + serviceContext, + }); + + // Query the router query engine + const summaryResponse = await queryEngine.query({ + query: "Give me a summary about his past experiences?", + }); + + console.log({ + answer: summaryResponse.response, + metadata: summaryResponse?.metadata?.selectorResult, + }); + + const specificResponse = await queryEngine.query({ + query: "Tell me about abramov first job?", + }); + + console.log({ + answer: specificResponse.response, + metadata: specificResponse.metadata.selectorResult, + }); +} + +main().then(() => console.log("Done")); +``` diff --git a/examples/routerQueryEngine.ts b/examples/routerQueryEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..57075f0db5317f934fc65d19ca70adb253725f25 --- /dev/null +++ b/examples/routerQueryEngine.ts @@ -0,0 +1,76 @@ +import { + OpenAI, + RouterQueryEngine, + SimpleDirectoryReader, + SimpleNodeParser, + SummaryIndex, + VectorStoreIndex, + serviceContextFromDefaults, +} from "llamaindex"; + +async function main() { + // Load documents from a directory + const documents = await new SimpleDirectoryReader().loadData({ + directoryPath: "node_modules/llamaindex/examples", + }); + + // Parse the documents into nodes + const nodeParser = new SimpleNodeParser({ + chunkSize: 1024, + }); + + // Create a service context + const serviceContext = serviceContextFromDefaults({ + nodeParser, + llm: new OpenAI(), + }); + + // Create indices + const vectorIndex = await VectorStoreIndex.fromDocuments(documents, { + serviceContext, + }); + + const summaryIndex = await SummaryIndex.fromDocuments(documents, { + serviceContext, + }); + + // Create query engines + const vectorQueryEngine = vectorIndex.asQueryEngine(); + const summaryQueryEngine = summaryIndex.asQueryEngine(); + + // Create a router query engine + const queryEngine = RouterQueryEngine.fromDefaults({ + queryEngineTools: [ + { + queryEngine: vectorQueryEngine, + description: "Useful for retrieving specific context from Abramov", + },
+ { + queryEngine: summaryQueryEngine, + description: "Useful for summarization questions related to Abramov", + }, + ], + serviceContext, + }); + + // Query the router query engine + const summaryResponse = await queryEngine.query({ + query: "Give me a summary about his past experiences?", + }); + + console.log({ + answer: summaryResponse.response, + metadata: summaryResponse?.metadata?.selectorResult, + }); + + const specificResponse = await queryEngine.query({ + query: "Tell me about abramov first job?", + }); + + console.log({ + answer: specificResponse.response, + metadata: specificResponse.metadata.selectorResult, + }); +} + +main().then(() => console.log("Done")); diff --git a/packages/core/src/OutputParser.ts b/packages/core/src/OutputParser.ts index 052eaf8c4aa62b970e2192b547d2aa84717500e7..2b09bb4b7823ddf52dd3162935f3da0050403b8e 100644 --- a/packages/core/src/OutputParser.ts +++ b/packages/core/src/OutputParser.ts @@ -1,4 +1,5 @@ -import { BaseOutputParser, StructuredOutput, SubQuestion } from "./types"; +import { SubQuestion } from "./engines/query/types"; +import { BaseOutputParser, StructuredOutput } from "./types"; /** * Error class for output parsing. Due to the nature of LLMs, anytime we use LLM diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 8e181fb9286dd033a0c732a61badfb1e18bc67db..a4d3da23ed9eee6f2a9bbc9163eb9b5787f652cf 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -1,5 +1,6 @@ +import { SubQuestion } from "./engines/query/types"; import { ChatMessage } from "./llm/types"; -import { SubQuestion, ToolMetadata } from "./types"; +import { ToolMetadata } from "./types"; /** * A SimplePrompt is a function that takes a dictionary of inputs and returns a string. 
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts index b52ed5b04d7421de44d51adc8963e8c16a36a891..ae4c0feb0e94ebbe26ba03a50fcfce3cc398af01 100644 --- a/packages/core/src/QuestionGenerator.ts +++ b/packages/core/src/QuestionGenerator.ts @@ -4,15 +4,10 @@ import { buildToolsText, defaultSubQuestionPrompt, } from "./Prompt"; +import { BaseQuestionGenerator, SubQuestion } from "./engines/query/types"; import { OpenAI } from "./llm/LLM"; import { LLM } from "./llm/types"; -import { - BaseOutputParser, - BaseQuestionGenerator, - StructuredOutput, - SubQuestion, - ToolMetadata, -} from "./types"; +import { BaseOutputParser, StructuredOutput, ToolMetadata } from "./types"; /** * LLMQuestionGenerator uses the LLM to generate new questions for the LLM using tools and a user query. diff --git a/packages/core/src/Response.ts b/packages/core/src/Response.ts index 9112efa657c25f5c23c824c1cf72ca89d9142ec4..6f651587611d16af3c09db4eec59f72d9f5a4b19 100644 --- a/packages/core/src/Response.ts +++ b/packages/core/src/Response.ts @@ -6,6 +6,7 @@ import { BaseNode } from "./Node"; export class Response { response: string; sourceNodes?: BaseNode[]; + metadata: Record<string, unknown> = {}; constructor(response: string, sourceNodes?: BaseNode[]) { this.response = response; diff --git a/packages/core/src/engines/query/RetrieverQueryEngine.ts b/packages/core/src/engines/query/RetrieverQueryEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..51470a387e4a1a64afb3f49b0012042387358b19 --- /dev/null +++ b/packages/core/src/engines/query/RetrieverQueryEngine.ts @@ -0,0 +1,82 @@ +import { NodeWithScore } from "../../Node"; +import { Response } from "../../Response"; +import { BaseRetriever } from "../../Retriever"; +import { ServiceContext } from "../../ServiceContext"; +import { Event } from "../../callbacks/CallbackManager"; +import { randomUUID } from "../../env"; +import { BaseNodePostprocessor } from "../../postprocessors"; 
+import { BaseSynthesizer, ResponseSynthesizer } from "../../synthesizers"; +import { + BaseQueryEngine, + QueryEngineParamsNonStreaming, + QueryEngineParamsStreaming, +} from "../../types"; + +/** + * A query engine that uses a retriever to query an index and then synthesizes the response. + */ +export class RetrieverQueryEngine implements BaseQueryEngine { + retriever: BaseRetriever; + responseSynthesizer: BaseSynthesizer; + nodePostprocessors: BaseNodePostprocessor[]; + preFilters?: unknown; + + constructor( + retriever: BaseRetriever, + responseSynthesizer?: BaseSynthesizer, + preFilters?: unknown, + nodePostprocessors?: BaseNodePostprocessor[], + ) { + this.retriever = retriever; + const serviceContext: ServiceContext | undefined = + this.retriever.getServiceContext(); + this.responseSynthesizer = + responseSynthesizer || new ResponseSynthesizer({ serviceContext }); + this.preFilters = preFilters; + this.nodePostprocessors = nodePostprocessors || []; + } + + private applyNodePostprocessors(nodes: NodeWithScore[]) { + return this.nodePostprocessors.reduce( + (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), + nodes, + ); + } + + private async retrieve(query: string, parentEvent: Event) { + const nodes = await this.retriever.retrieve( + query, + parentEvent, + this.preFilters, + ); + + return this.applyNodePostprocessors(nodes); + } + + query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; + query(params: QueryEngineParamsNonStreaming): Promise<Response>; + async query( + params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { query, stream } = params; + const parentEvent: Event = params.parentEvent || { + id: randomUUID(), + type: "wrapper", + tags: ["final"], + }; + const nodesWithScore = await this.retrieve(query, parentEvent); + if (stream) { + return this.responseSynthesizer.synthesize({ + query, + nodesWithScore, + parentEvent, + stream: 
true, + }); + } + return this.responseSynthesizer.synthesize({ + query, + nodesWithScore, + parentEvent, + }); + } +} diff --git a/packages/core/src/engines/query/RouterQueryEngine.ts b/packages/core/src/engines/query/RouterQueryEngine.ts new file mode 100644 index 0000000000000000000000000000000000000000..babe01d14729e057e2e40948002ef657498dcfe3 --- /dev/null +++ b/packages/core/src/engines/query/RouterQueryEngine.ts @@ -0,0 +1,181 @@ +import { Response } from "../../Response"; +import { + ServiceContext, + serviceContextFromDefaults, +} from "../../ServiceContext"; +import { BaseSelector, LLMSingleSelector } from "../../selectors"; +import { TreeSummarize } from "../../synthesizers"; +import { + BaseQueryEngine, + QueryBundle, + QueryEngineParamsNonStreaming, + QueryEngineParamsStreaming, +} from "../../types"; + +type RouterQueryEngineTool = { + queryEngine: BaseQueryEngine; + description: string; +}; + +type RouterQueryEngineMetadata = { + description: string; +}; + +async function combineResponses( + summarizer: TreeSummarize, + responses: Response[], + queryBundle: QueryBundle, + verbose: boolean = false, +): Promise<Response> { + if (verbose) { + console.log("Combining responses from multiple query engines."); + } + + const responseStrs = []; + const sourceNodes = []; + + for (const response of responses) { + if (response?.sourceNodes) { + sourceNodes.push(...response.sourceNodes); + } + + responseStrs.push(response.response); + } + + const summary = await summarizer.getResponse({ + query: queryBundle.queryStr, + textChunks: responseStrs, + }); + + return new Response(summary, sourceNodes); +} + +/** + * A query engine that uses multiple query engines and selects the best one. 
+ */ +export class RouterQueryEngine implements BaseQueryEngine { + serviceContext: ServiceContext; + + private selector: BaseSelector; + private queryEngines: BaseQueryEngine[]; + private metadatas: RouterQueryEngineMetadata[]; + private summarizer: TreeSummarize; + private verbose: boolean; + + constructor(init: { + selector: BaseSelector; + queryEngineTools: RouterQueryEngineTool[]; + serviceContext?: ServiceContext; + summarizer?: TreeSummarize; + verbose?: boolean; + }) { + this.serviceContext = init.serviceContext || serviceContextFromDefaults({}); + this.selector = init.selector; + this.queryEngines = init.queryEngineTools.map((tool) => tool.queryEngine); + this.metadatas = init.queryEngineTools.map((tool) => ({ + description: tool.description, + })); + this.summarizer = init.summarizer || new TreeSummarize(this.serviceContext); + this.verbose = init.verbose ?? false; + } + + static fromDefaults(init: { + queryEngineTools: RouterQueryEngineTool[]; + selector?: BaseSelector; + serviceContext?: ServiceContext; + summarizer?: TreeSummarize; + verbose?: boolean; + }) { + const serviceContext = + init.serviceContext ?? serviceContextFromDefaults({}); + + return new RouterQueryEngine({ + selector: + init.selector ?? 
new LLMSingleSelector({ llm: serviceContext.llm }), + queryEngineTools: init.queryEngineTools, + serviceContext, + summarizer: init.summarizer, + verbose: init.verbose, + }); + } + + query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; + query(params: QueryEngineParamsNonStreaming): Promise<Response>; + async query( + params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, + ): Promise<Response | AsyncIterable<Response>> { + const { query, stream } = params; + + const response = await this.queryRoute({ queryStr: query }); + + if (stream) { + throw new Error("Streaming is not supported yet."); + } + + return response; + } + + private async queryRoute(queryBundle: QueryBundle): Promise<Response> { + const result = await this.selector.select(this.metadatas, queryBundle); + + if (result.selections.length > 1) { + const responses = []; + for (let i = 0; i < result.selections.length; i++) { + const engineInd = result.selections[i]; + const logStr = `Selecting query engine ${engineInd}: ${result.selections[i]}.`; + + if (this.verbose) { + console.log(logStr + "\n"); + } + + const selectedQueryEngine = this.queryEngines[engineInd.index]; + responses.push( + await selectedQueryEngine.query({ + query: queryBundle.queryStr, + }), + ); + } + + if (responses.length > 1) { + const finalResponse = await combineResponses( + this.summarizer, + responses, + queryBundle, + this.verbose, + ); + + return finalResponse; + } else { + return responses[0]; + } + } else { + let selectedQueryEngine; + + try { + selectedQueryEngine = this.queryEngines[result.selections[0].index]; + + const logStr = `Selecting query engine ${result.selections[0].index}: ${result.selections[0].reason}`; + + if (this.verbose) { + console.log(logStr + "\n"); + } + } catch (e) { + throw new Error("Failed to select query engine"); + } + + if (!selectedQueryEngine) { + throw new Error("Selected query engine is null"); + } + + const finalResponse = await selectedQueryEngine.query({ 
+ query: queryBundle.queryStr, + }); + + // add selected result + finalResponse.metadata = finalResponse.metadata || {}; + finalResponse.metadata["selectorResult"] = result; + + return finalResponse; + } + } +} diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/engines/query/SubQuestionQueryEngine.ts similarity index 58% rename from packages/core/src/QueryEngine.ts rename to packages/core/src/engines/query/SubQuestionQueryEngine.ts index 8a3a2437eee25437f42237404a83713cb43758d6..a70dfbb9e98ffc24037aabe9883fcdf763633b23 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/engines/query/SubQuestionQueryEngine.ts @@ -1,94 +1,25 @@ -import { NodeWithScore, TextNode } from "./Node"; -import { LLMQuestionGenerator } from "./QuestionGenerator"; -import { Response } from "./Response"; -import { BaseRetriever } from "./Retriever"; -import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; -import { Event } from "./callbacks/CallbackManager"; -import { randomUUID } from "./env"; -import { BaseNodePostprocessor } from "./postprocessors"; +import { NodeWithScore, TextNode } from "../../Node"; +import { LLMQuestionGenerator } from "../../QuestionGenerator"; +import { Response } from "../../Response"; +import { + ServiceContext, + serviceContextFromDefaults, +} from "../../ServiceContext"; +import { Event } from "../../callbacks/CallbackManager"; +import { randomUUID } from "../../env"; import { BaseSynthesizer, CompactAndRefine, ResponseSynthesizer, -} from "./synthesizers"; +} from "../../synthesizers"; import { BaseQueryEngine, - BaseQuestionGenerator, QueryEngineParamsNonStreaming, QueryEngineParamsStreaming, QueryEngineTool, - SubQuestion, ToolMetadata, -} from "./types"; - -/** - * A query engine that uses a retriever to query an index and then synthesizes the response. 
- */ -export class RetrieverQueryEngine implements BaseQueryEngine { - retriever: BaseRetriever; - responseSynthesizer: BaseSynthesizer; - nodePostprocessors: BaseNodePostprocessor[]; - preFilters?: unknown; - - constructor( - retriever: BaseRetriever, - responseSynthesizer?: BaseSynthesizer, - preFilters?: unknown, - nodePostprocessors?: BaseNodePostprocessor[], - ) { - this.retriever = retriever; - const serviceContext: ServiceContext | undefined = - this.retriever.getServiceContext(); - this.responseSynthesizer = - responseSynthesizer || new ResponseSynthesizer({ serviceContext }); - this.preFilters = preFilters; - this.nodePostprocessors = nodePostprocessors || []; - } - - private applyNodePostprocessors(nodes: NodeWithScore[]) { - return this.nodePostprocessors.reduce( - (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), - nodes, - ); - } - - private async retrieve(query: string, parentEvent: Event) { - const nodes = await this.retriever.retrieve( - query, - parentEvent, - this.preFilters, - ); - - return this.applyNodePostprocessors(nodes); - } - - query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>; - query(params: QueryEngineParamsNonStreaming): Promise<Response>; - async query( - params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming, - ): Promise<Response | AsyncIterable<Response>> { - const { query, stream } = params; - const parentEvent: Event = params.parentEvent || { - id: randomUUID(), - type: "wrapper", - tags: ["final"], - }; - const nodesWithScore = await this.retrieve(query, parentEvent); - if (stream) { - return this.responseSynthesizer.synthesize({ - query, - nodesWithScore, - parentEvent, - stream: true, - }); - } - return this.responseSynthesizer.synthesize({ - query, - nodesWithScore, - parentEvent, - }); - } -} +} from "../../types"; +import { BaseQuestionGenerator, SubQuestion } from "./types"; /** * SubQuestionQueryEngine decomposes a question into subquestions and then diff --git 
a/packages/core/src/engines/query/index.ts b/packages/core/src/engines/query/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..ef1cb9f8341ea1835ab92bc9bafa33b6f36416ab --- /dev/null +++ b/packages/core/src/engines/query/index.ts @@ -0,0 +1,3 @@ +export * from "./RetrieverQueryEngine"; +export * from "./RouterQueryEngine"; +export * from "./SubQuestionQueryEngine"; diff --git a/packages/core/src/engines/query/types.ts b/packages/core/src/engines/query/types.ts new file mode 100644 index 0000000000000000000000000000000000000000..4fd9c63dd838bebfe3a764cc9f16389893510485 --- /dev/null +++ b/packages/core/src/engines/query/types.ts @@ -0,0 +1,13 @@ +import { ToolMetadata } from "../../types"; + +/** + * QuestionGenerators generate new questions for the LLM using tools and a user query. + */ +export interface BaseQuestionGenerator { + generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>; +} + +export interface SubQuestion { + subQuestion: string; + toolName: string; +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index a668a4b9e50baf5bd9e57e98d4128ad605a1db86..0d02cbc6e8de89d94e3335406b8d846509b9fb17 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -4,7 +4,6 @@ export * from "./Node"; export * from "./OutputParser"; export * from "./Prompt"; export * from "./PromptHelper"; -export * from "./QueryEngine"; export * from "./QuestionGenerator"; export * from "./Response"; export * from "./Retriever"; @@ -14,6 +13,7 @@ export * from "./callbacks/CallbackManager"; export * from "./constants"; export * from "./embeddings"; export * from "./engines/chat"; +export * from "./engines/query"; export * from "./extractors"; export * from "./indices"; export * from "./ingestion"; @@ -30,6 +30,7 @@ export * from "./readers/PDFReader"; export * from "./readers/SimpleDirectoryReader"; export * from "./readers/SimpleMongoReader"; export * from "./readers/base"; +export * from "./selectors"; 
export * from "./storage"; export * from "./synthesizers"; export type * from "./types"; diff --git a/packages/core/src/indices/keyword/KeywordTableIndex.ts b/packages/core/src/indices/keyword/KeywordTableIndex.ts index 8072aeb63673be9b6842d0d962dca4e9713d3046..6964ebb2deb4f518d7f4ac8e8331506a5317ebe9 100644 --- a/packages/core/src/indices/keyword/KeywordTableIndex.ts +++ b/packages/core/src/indices/keyword/KeywordTableIndex.ts @@ -1,11 +1,11 @@ import { BaseNode, Document, MetadataMode } from "../../Node"; import { defaultKeywordExtractPrompt } from "../../Prompt"; -import { RetrieverQueryEngine } from "../../QueryEngine"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; +import { RetrieverQueryEngine } from "../../engines/query"; import { BaseNodePostprocessor } from "../../postprocessors"; import { BaseDocumentStore, diff --git a/packages/core/src/indices/summary/SummaryIndex.ts b/packages/core/src/indices/summary/SummaryIndex.ts index 774dfac2a4ec9bc6c5b52c8b6e83752df9cb3127..adcff31dee32a9d0448a50f6f9dc6637707dfdbe 100644 --- a/packages/core/src/indices/summary/SummaryIndex.ts +++ b/packages/core/src/indices/summary/SummaryIndex.ts @@ -1,11 +1,11 @@ import _ from "lodash"; import { BaseNode, Document } from "../../Node"; -import { RetrieverQueryEngine } from "../../QueryEngine"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; +import { RetrieverQueryEngine } from "../../engines/query"; import { BaseNodePostprocessor } from "../../postprocessors"; import { BaseDocumentStore, diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index 2403e1b54f4ac4366f8533f6caadc652e1dd6a67..75a95fbf75e51a1c3837684929ccd44e6e18132c 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ 
b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -6,7 +6,6 @@ import { ObjectType, splitNodesByType, } from "../../Node"; -import { RetrieverQueryEngine } from "../../QueryEngine"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext, @@ -17,6 +16,7 @@ import { ClipEmbedding, MultiModalEmbedding, } from "../../embeddings"; +import { RetrieverQueryEngine } from "../../engines/query"; import { runTransformations } from "../../ingestion"; import { BaseNodePostprocessor } from "../../postprocessors"; import { diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 3df872512f4e497205feaf6009bbc67bcbb77a2f..e543164ef933e92b8c5ede08cf11f07a4e578c44 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -216,6 +216,7 @@ export class OpenAI extends BaseLLM { top_p: this.topP, ...this.additionalChatOptions, }; + // Streaming if (stream) { return this.streamChat(params); diff --git a/packages/core/src/outputParsers/selectors.ts b/packages/core/src/outputParsers/selectors.ts new file mode 100644 index 0000000000000000000000000000000000000000..9c0af07e271d3c54e227bdcf688886e02996dabd --- /dev/null +++ b/packages/core/src/outputParsers/selectors.ts @@ -0,0 +1,52 @@ +import { parseJsonMarkdown } from "../OutputParser"; +import { BaseOutputParser, StructuredOutput } from "../types"; + +export type Answer = { + choice: number; + reason: string; +}; + +const formatStr = `The output should be ONLY JSON formatted as a JSON instance. + +Here is an example: +[ + { + choice: 1, + reason: "<insert reason for choice>" + }, + ... +] +`; + +/* + * An OutputParser is used to extract structured data from the raw output of the LLM. 
+ */ +export class SelectionOutputParser + implements BaseOutputParser<StructuredOutput<Answer[]>> +{ + /** + * + * @param output + */ + parse(output: string): StructuredOutput<Answer[]> { + let parsed; + + try { + parsed = parseJsonMarkdown(output); + } catch (e) { + try { + parsed = JSON.parse(output); + } catch (e) { + throw new Error( + `Got invalid JSON object. Error: ${e}. Got JSON string: ${output}`, + ); + } + } + + return { rawOutput: output, parsedOutput: parsed }; + } + + format(output: string): string { + return output + "\n\n" + formatStr; + } +} diff --git a/packages/core/src/selectors/base.ts b/packages/core/src/selectors/base.ts new file mode 100644 index 0000000000000000000000000000000000000000..74bd89e372f1afb07cd095d9637bbcfa3ad0ab19 --- /dev/null +++ b/packages/core/src/selectors/base.ts @@ -0,0 +1,45 @@ +import { QueryBundle, ToolMetadataOnlyDescription } from "../types"; + +export interface SingleSelection { + index: number; + reason: string; +} + +export type SelectorResult = { + selections: SingleSelection[]; +}; + +type QueryType = string | QueryBundle; + +function wrapChoice( + choice: string | ToolMetadataOnlyDescription, +): ToolMetadataOnlyDescription { + if (typeof choice === "string") { + return { description: choice }; + } else { + return choice; + } +} + +function wrapQuery(query: QueryType): QueryBundle { + if (typeof query === "string") { + return { queryStr: query }; + } + + return query; +} + +type MetadataType = string | ToolMetadataOnlyDescription; + +export abstract class BaseSelector { + async select(choices: MetadataType[], query: QueryType) { + const metadatas = choices.map((choice) => wrapChoice(choice)); + const queryBundle = wrapQuery(query); + return await this._select(metadatas, queryBundle); + } + + abstract _select( + choices: ToolMetadataOnlyDescription[], + query: QueryBundle, + ): Promise<SelectorResult>; +} diff --git a/packages/core/src/selectors/index.ts b/packages/core/src/selectors/index.ts new file mode 
100644 index 0000000000000000000000000000000000000000..c3dd0aaf7badf6809b04ef7c7cb51e303e1514af --- /dev/null +++ b/packages/core/src/selectors/index.ts @@ -0,0 +1,3 @@ +export * from "./base"; +export * from "./llmSelectors"; +export * from "./utils"; diff --git a/packages/core/src/selectors/llmSelectors.ts b/packages/core/src/selectors/llmSelectors.ts new file mode 100644 index 0000000000000000000000000000000000000000..74acdc87f8c8065913fe889759ff9ab95990a64f --- /dev/null +++ b/packages/core/src/selectors/llmSelectors.ts @@ -0,0 +1,168 @@ +import { DefaultPromptTemplate } from "../extractors/prompts"; +import { LLM } from "../llm"; +import { Answer, SelectionOutputParser } from "../outputParsers/selectors"; +import { + BaseOutputParser, + QueryBundle, + StructuredOutput, + ToolMetadataOnlyDescription, +} from "../types"; +import { BaseSelector, SelectorResult } from "./base"; +import { defaultSingleSelectPrompt } from "./prompts"; + +function buildChoicesText(choices: ToolMetadataOnlyDescription[]): string { + const texts: string[] = []; + for (const [ind, choice] of choices.entries()) { + let text = choice.description.split("\n").join(" "); + text = `(${ind + 1}) ${text}`; // to one indexing + texts.push(text); + } + return texts.join(""); +} + +function _structuredOutputToSelectorResult( + output: StructuredOutput<Answer[]>, +): SelectorResult { + const structuredOutput = output; + const answers = structuredOutput.parsedOutput; + + // adjust for zero indexing + const selections = answers.map((answer: any) => { + return { index: answer.choice - 1, reason: answer.reason }; + }); + + return { selections }; +} + +type LLMPredictorType = LLM; + +/** + * A selector that uses the LLM to select a single or multiple choices from a list of choices. 
+ */ +export class LLMMultiSelector extends BaseSelector { + _llm: LLMPredictorType; + _prompt: DefaultPromptTemplate | undefined; + _maxOutputs: number | null; + _outputParser: BaseOutputParser<any> | null; + + constructor(init: { + llm: LLMPredictorType; + prompt?: DefaultPromptTemplate; + maxOutputs?: number; + outputParser?: BaseOutputParser<any>; + }) { + super(); + this._llm = init.llm; + this._prompt = init.prompt; + this._maxOutputs = init.maxOutputs ?? null; + + this._outputParser = init.outputParser ?? new SelectionOutputParser(); + } + + _getPrompts(): Record<string, any> { + return { prompt: this._prompt }; + } + + _updatePrompts(prompts: Record<string, any>): void { + if ("prompt" in prompts) { + this._prompt = prompts.prompt; + } + } + + /** + * Selects a single choice from a list of choices. + * @param choices + * @param query + */ + async _select( + choices: ToolMetadataOnlyDescription[], + query: QueryBundle, + ): Promise<SelectorResult> { + const choicesText = buildChoicesText(choices); + + const prompt = + this._prompt?.contextStr ?? + defaultSingleSelectPrompt( + choicesText.length, + choicesText, + query.queryStr, + ); + const formattedPrompt = this._outputParser?.format(prompt); + + const prediction = await this._llm.complete({ + prompt: formattedPrompt, + }); + + const parsed = this._outputParser?.parse(prediction.text); + + return _structuredOutputToSelectorResult(parsed); + } + + asQueryComponent(): unknown { + throw new Error("Method not implemented."); + } +} + +/** + * A selector that uses the LLM to select a single choice from a list of choices. 
+ */ +export class LLMSingleSelector extends BaseSelector { + _llm: LLMPredictorType; + _prompt: DefaultPromptTemplate | undefined; + _outputParser: BaseOutputParser<any> | null; + + constructor(init: { + llm: LLMPredictorType; + prompt?: DefaultPromptTemplate; + outputParser?: BaseOutputParser<any>; + }) { + super(); + this._llm = init.llm; + this._prompt = init.prompt; + this._outputParser = init.outputParser ?? new SelectionOutputParser(); + } + + _getPrompts(): Record<string, any> { + return { prompt: this._prompt }; + } + + _updatePrompts(prompts: Record<string, any>): void { + if ("prompt" in prompts) { + this._prompt = prompts.prompt; + } + } + + /** + * Selects a single choice from a list of choices. + * @param choices + * @param query + */ + async _select( + choices: ToolMetadataOnlyDescription[], + query: QueryBundle, + ): Promise<SelectorResult> { + const choicesText = buildChoicesText(choices); + + const prompt = + this._prompt?.contextStr ?? + defaultSingleSelectPrompt( + choices.length, + choicesText, + query.queryStr, + ); + + const formattedPrompt = this._outputParser?.format(prompt); + + const prediction = await this._llm.complete({ + prompt: formattedPrompt, + }); + + const parsed = this._outputParser?.parse(prediction.text); + + return _structuredOutputToSelectorResult(parsed); + } + + asQueryComponent(): unknown { + throw new Error("Method not implemented."); + } +} diff --git a/packages/core/src/selectors/prompts.ts b/packages/core/src/selectors/prompts.ts new file mode 100644 index 0000000000000000000000000000000000000000..b915271214a9edea4858ed51e2d2ae2e7a5e0d12 --- /dev/null +++ b/packages/core/src/selectors/prompts.ts @@ -0,0 +1,30 @@ +export const defaultSingleSelectPrompt = ( + numChoices: number, + contextList: string, + queryStr: string, +): string => { + return `Some choices are given below. It is provided in a numbered list (1 to ${numChoices}), where each item in the list corresponds to a summary.
+--------------------- +${contextList} +--------------------- +Using only the choices above and not prior knowledge, return the choice that is most relevant to the question: '${queryStr}' +`; +}; + +export type SingleSelectPrompt = typeof defaultSingleSelectPrompt; + +export const defaultMultiSelectPrompt = ( + numChoices: number, + contextList: string, + queryStr: string, + maxOutputs: number, +) => { + return `Some choices are given below. It is provided in a numbered list (1 to ${numChoices}), where each item in the list corresponds to a summary. +--------------------- +${contextList} +--------------------- +Using only the choices above and not prior knowledge, return the top choices (no more than ${maxOutputs}, but only select what is needed) that are most relevant to the question: '${queryStr}' +`; +}; + +export type MultiSelectPrompt = typeof defaultMultiSelectPrompt; diff --git a/packages/core/src/selectors/utils.ts b/packages/core/src/selectors/utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..08b2226e165f07e4b1b62b727aa5259754454e61 --- /dev/null +++ b/packages/core/src/selectors/utils.ts @@ -0,0 +1,24 @@ +import { ServiceContext } from "../ServiceContext"; +import { BaseSelector } from "./base"; +import { LLMMultiSelector, LLMSingleSelector } from "./llmSelectors"; + +export const getSelectorFromContext = ( + serviceContext: ServiceContext, + isMulti: boolean = false, +): BaseSelector => { + let selector: BaseSelector | null = null; + + const llm = serviceContext.llm; + + if (isMulti) { + selector = new LLMMultiSelector({ llm }); + } else { + selector = new LLMSingleSelector({ llm }); + } + + if (selector === null) { + throw new Error("Selector is null"); + } + + return selector; +}; diff --git a/packages/core/src/tests/Selectors.test.ts b/packages/core/src/tests/Selectors.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..a615f03fbf9642151037b848d026054d5d17d0b0 --- /dev/null +++ 
b/packages/core/src/tests/Selectors.test.ts @@ -0,0 +1,38 @@ +// from unittest.mock import patch + +import { serviceContextFromDefaults } from "../ServiceContext"; +import { OpenAI } from "../llm"; +import { LLMSingleSelector } from "../selectors"; +import { mockStructuredLlmGeneration } from "./utility/mockOpenAI"; + +jest.mock("../llm/open_ai", () => { + return { + getOpenAISession: jest.fn().mockImplementation(() => null), + }; +}); + +describe("LLMSelector", () => { + test("should be able to output a selection with a reason", async () => { + const serviceContext = serviceContextFromDefaults({}); + + const languageModel = new OpenAI({ + model: "gpt-3.5-turbo", + }); + + mockStructuredLlmGeneration({ + languageModel, + callbackManager: serviceContext.callbackManager, + }); + + const selector = new LLMSingleSelector({ + llm: languageModel, + }); + + const result = await selector.select( + ["apple", "pear", "peach"], + "what is the best fruit?", + ); + + expect(result.selections[0].reason).toBe("apple"); + }); +}); diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts index 84f6925d746f6256c9300f203ca64297f6853abe..e06053ed9384f0fffa29bebbaddc9965bc7b0bba 100644 --- a/packages/core/src/tests/utility/mockOpenAI.ts +++ b/packages/core/src/tests/utility/mockOpenAI.ts @@ -76,3 +76,68 @@ export function mockEmbeddingModel(embedModel: OpenAIEmbedding) { }); }); } + +const structuredOutput = JSON.stringify([ + { + choice: 1, + reason: "apple", + }, +]); + +export function mockStructuredLlmGeneration({ + languageModel, + callbackManager, +}: { + languageModel: OpenAI; + callbackManager: CallbackManager; +}) { + jest + .spyOn(languageModel, "chat") + .mockImplementation( + async ({ messages, parentEvent }: LLMChatParamsBase) => { + const text = structuredOutput; + const event = globalsHelper.createEvent({ + parentEvent, + type: "llmPredict", + }); + if (callbackManager?.onLLMStream) { + const chunks = text.split("-"); + 
for (let i = 0; i < chunks.length; i++) { + const chunk = chunks[i]; + callbackManager?.onLLMStream({ + event, + index: i, + token: { + id: "id", + object: "object", + created: 1, + model: "model", + choices: [ + { + index: 0, + delta: { + content: chunk, + }, + finish_reason: null, + }, + ], + }, + }); + } + callbackManager?.onLLMStream({ + event, + index: chunks.length, + isDone: true, + }); + } + return new Promise((resolve) => { + resolve({ + message: { + content: text, + role: "assistant", + }, + }); + }); + }, + ); +} diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 4f38fdcaf342e9e4964686fe857e3e3503ebb953..3f8074980fbeef55fd7c964d78266529b5a7c89b 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -47,11 +47,6 @@ export interface QueryEngineTool extends BaseTool { queryEngine: BaseQueryEngine; } -export interface SubQuestion { - subQuestion: string; - toolName: string; -} - /** * An OutputParser is used to extract structured data from the raw output of the LLM. */ @@ -74,9 +69,16 @@ export interface ToolMetadata { name: string; } -/** - * QuestionGenerators generate new questions for the LLM using tools and a user query. - */ -export interface BaseQuestionGenerator { - generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>; +export type ToolMetadataOnlyDescription = Pick<ToolMetadata, "description">; + +export class QueryBundle { + queryStr: string; + + constructor(queryStr: string) { + this.queryStr = queryStr; + } + + toString(): string { + return this.queryStr; + } }