From 92f07824a75cab09ca635e982a3fe73f74eecf3f Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Wed, 17 Jul 2024 20:17:06 -0700 Subject: [PATCH] feat: use query bundle (#702) --- .changeset/quiet-cows-rule.md | 5 ++++ .../src/engines/query/RouterQueryEngine.ts | 10 ++++---- packages/llamaindex/src/internal/utils.ts | 8 +++++++ packages/llamaindex/src/prompts/Mixin.ts | 1 + packages/llamaindex/src/selectors/base.ts | 19 ++++----------- .../llamaindex/src/selectors/llmSelectors.ts | 9 ++++++-- .../llamaindex/src/synthesizers/builders.ts | 23 ++++++++++++------- packages/llamaindex/src/types.ts | 23 +++++++++---------- 8 files changed, 58 insertions(+), 40 deletions(-) create mode 100644 .changeset/quiet-cows-rule.md diff --git a/.changeset/quiet-cows-rule.md b/.changeset/quiet-cows-rule.md new file mode 100644 index 000000000..53cd41ea4 --- /dev/null +++ b/.changeset/quiet-cows-rule.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: use query bundle diff --git a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts index 3239e1e84..408d70c4d 100644 --- a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts @@ -1,7 +1,9 @@ import type { NodeWithScore } from "@llamaindex/core/schema"; +import { extractText } from "@llamaindex/core/utils"; import { EngineResponse } from "../../EngineResponse.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; +import { toQueryBundle } from "../../internal/utils.js"; import { PromptMixin } from "../../prompts/index.js"; import type { BaseSelector } from "../../selectors/index.js"; import { LLMSingleSelector } from "../../selectors/index.js"; @@ -44,7 +46,7 @@ async function combineResponses( } const summary = await summarizer.getResponse({ - query: queryBundle.queryStr, + query: extractText(queryBundle.query), textChunks: responseStrs, }); @@ -117,7 +119,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { ): Promise<EngineResponse | AsyncIterable<EngineResponse>> { const { query, stream } = params; - const response = await this.queryRoute({ queryStr: query }); + const response = await this.queryRoute(toQueryBundle(query)); if (stream) { throw new Error("Streaming is not supported yet."); @@ -142,7 +144,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { const selectedQueryEngine = this.queryEngines[engineInd.index]; responses.push( await selectedQueryEngine.query({ - query: queryBundle.queryStr, + query: extractText(queryBundle.query), }), ); } @@ -179,7 +181,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { } const finalResponse = await selectedQueryEngine.query({ - query: queryBundle.queryStr, + query: extractText(queryBundle.query), }); // add selected result diff --git a/packages/llamaindex/src/internal/utils.ts b/packages/llamaindex/src/internal/utils.ts index a301c2707..c3395bea4 100644 --- a/packages/llamaindex/src/internal/utils.ts +++ b/packages/llamaindex/src/internal/utils.ts @@ -3,6 +3,7 @@ import type { JSONValue } from "@llamaindex/core/global"; import type { ImageType } from "@llamaindex/core/schema"; import { fs } from "@llamaindex/env"; import { filetypemime } from "magic-bytes.js"; +import type { QueryBundle } from "../types.js"; export const isAsyncIterable = ( obj: unknown, @@ -202,3 +203,10 @@ export async function imageToDataUrl(input: ImageType): Promise<string> { } return await blobToDataUrl(input); } + +export function toQueryBundle(query: QueryBundle | string): QueryBundle { + if (typeof query === "string") { + return { query }; + } + return query; +} diff --git a/packages/llamaindex/src/prompts/Mixin.ts b/packages/llamaindex/src/prompts/Mixin.ts index b49d2db12..5d0942137 100644 --- a/packages/llamaindex/src/prompts/Mixin.ts +++ b/packages/llamaindex/src/prompts/Mixin.ts @@ -75,6 +75,7 @@ export class PromptMixin { } // Must be implemented by subclasses + // fixme: says must but never implemented protected _getPrompts(): PromptsDict { return {}; } diff --git a/packages/llamaindex/src/selectors/base.ts b/packages/llamaindex/src/selectors/base.ts index a5ef61b41..22a5c66da 100644 --- a/packages/llamaindex/src/selectors/base.ts +++ b/packages/llamaindex/src/selectors/base.ts @@ -1,3 +1,4 @@ +import { toQueryBundle } from "../internal/utils.js"; import { PromptMixin } from "../prompts/Mixin.js"; import type { QueryBundle, ToolMetadataOnlyDescription } from "../types.js"; @@ -10,8 +11,6 @@ export type SelectorResult = { selections: SingleSelection[]; }; -type QueryType = string | QueryBundle; - function wrapChoice( choice: string | ToolMetadataOnlyDescription, ): ToolMetadataOnlyDescription { @@ -22,21 +21,13 @@ function wrapChoice( } } -function wrapQuery(query: QueryType): QueryBundle { - if (typeof query === "string") { - return { queryStr: query }; - } - - return query; -} - type MetadataType = string | ToolMetadataOnlyDescription; export abstract class BaseSelector extends PromptMixin { - async select(choices: MetadataType[], query: QueryType) { - const metadatas = choices.map((choice) => wrapChoice(choice)); - const queryBundle = wrapQuery(query); - return await this._select(metadatas, queryBundle); + async select(choices: MetadataType[], query: string | QueryBundle) { + const metadata = choices.map((choice) => wrapChoice(choice)); + const queryBundle = toQueryBundle(query); + return await this._select(metadata, queryBundle); } abstract _select( diff --git a/packages/llamaindex/src/selectors/llmSelectors.ts b/packages/llamaindex/src/selectors/llmSelectors.ts index 242b9973a..e73966bfa 100644 --- a/packages/llamaindex/src/selectors/llmSelectors.ts +++ b/packages/llamaindex/src/selectors/llmSelectors.ts @@ -1,4 +1,5 @@ import type { LLM } from "@llamaindex/core/llms"; +import { extractText } from "@llamaindex/core/utils"; import type { Answer } from "../outputParsers/selectors.js"; import { SelectionOutputParser } from "../outputParsers/selectors.js"; import type { @@ -88,7 +89,7 @@ export class LLMMultiSelector extends BaseSelector { const prompt = this.prompt( choicesText.length, choicesText, - query.queryStr, + extractText(query.query), this.maxOutputs, ); @@ -152,7 +153,11 @@ export class LLMSingleSelector extends BaseSelector { ): Promise<SelectorResult> { const choicesText = buildChoicesText(choices); - const prompt = this.prompt(choicesText.length, choicesText, query.queryStr); + const prompt = this.prompt( + choicesText.length, + choicesText, + extractText(query.query), + ); const formattedPrompt = this.outputParser.format(prompt); diff --git a/packages/llamaindex/src/synthesizers/builders.ts b/packages/llamaindex/src/synthesizers/builders.ts index 901b6728a..f5ab12e81 100644 --- a/packages/llamaindex/src/synthesizers/builders.ts +++ b/packages/llamaindex/src/synthesizers/builders.ts @@ -1,5 +1,6 @@ import type { LLM } from "@llamaindex/core/llms"; -import { streamConverter } from "@llamaindex/core/utils"; +import { extractText, streamConverter } from "@llamaindex/core/utils"; +import { toQueryBundle } from "../internal/utils.js"; import type { RefinePrompt, SimplePrompt, @@ -61,7 +62,7 @@ export class SimpleResponseBuilder implements ResponseBuilder { AsyncIterable<string> | string > { const input = { - query, + query: extractText(toQueryBundle(query).query), context: textChunks.join("\n\n"), }; @@ -142,14 +143,14 @@ export class Refine extends PromptMixin implements ResponseBuilder { const lastChunk = i === textChunks.length - 1; if (!response) { response = await this.giveResponseSingle( - query, + extractText(toQueryBundle(query).query), chunk, !!stream && lastChunk, ); } else { response = await this.refineResponseSingle( response as string, - query, + extractText(toQueryBundle(query).query), chunk, !!stream && lastChunk, ); @@ -254,9 +255,15 @@ export class CompactAndRefine extends Refine { AsyncIterable<string> | string > { const textQATemplate: SimplePrompt = (input) => - this.textQATemplate({ ...input, query: query }); + this.textQATemplate({ + ...input, + query: extractText(toQueryBundle(query).query), + }); const refineTemplate: SimplePrompt = (input) => - this.refineTemplate({ ...input, query: query }); + this.refineTemplate({ + ...input, + query: extractText(toQueryBundle(query).query), + }); const maxPrompt = getBiggestPrompt([textQATemplate, refineTemplate]); const newTexts = this.promptHelper.repack(maxPrompt, textChunks); @@ -335,7 +342,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { const params = { prompt: this.summaryTemplate({ context: packedTextChunks[0], - query, + query: extractText(toQueryBundle(query).query), }), }; if (stream) { @@ -349,7 +356,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { this.llm.complete({ prompt: this.summaryTemplate({ context: chunk, - query, + query: extractText(toQueryBundle(query).query), }), }), ), diff --git a/packages/llamaindex/src/types.ts b/packages/llamaindex/src/types.ts index 19d697e3c..66cf1c5f7 100644 --- a/packages/llamaindex/src/types.ts +++ b/packages/llamaindex/src/types.ts @@ -1,7 +1,7 @@ /** * Top level types to avoid circular dependencies */ -import type { ToolMetadata } from "@llamaindex/core/llms"; +import type { MessageContent, ToolMetadata } from "@llamaindex/core/llms"; import type { EngineResponse } from "./EngineResponse.js"; /** @@ -52,16 +52,15 @@ export interface StructuredOutput<T> { export type ToolMetadataOnlyDescription = Pick<ToolMetadata, "description">; -export class QueryBundle { - queryStr: string; - - constructor(queryStr: string) { - this.queryStr = queryStr; - } - - toString(): string { - return this.queryStr; - } -} +/** + * @link https://docs.llamaindex.ai/en/stable/api_reference/schema/?h=querybundle#llama_index.core.schema.QueryBundle + * + * We don't have `image_path` here, because it is included in the `query` field. + */ +export type QueryBundle = { + query: string | MessageContent; + customEmbedding?: string[]; + embeddings?: number[]; +}; export type UUID = `${string}-${string}-${string}-${string}-${string}`; -- GitLab