diff --git a/.changeset/quiet-cows-rule.md b/.changeset/quiet-cows-rule.md new file mode 100644 index 0000000000000000000000000000000000000000..53cd41ea423b0b1e08f8507496e3958d0371690b --- /dev/null +++ b/.changeset/quiet-cows-rule.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: use query bundle diff --git a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts index 3239e1e84768d890adad3cb24c5ca006966799ed..408d70c4d91bcbb91905b70e0973d47c75b5c4a9 100644 --- a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts @@ -1,7 +1,9 @@ import type { NodeWithScore } from "@llamaindex/core/schema"; +import { extractText } from "@llamaindex/core/utils"; import { EngineResponse } from "../../EngineResponse.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { llmFromSettingsOrContext } from "../../Settings.js"; +import { toQueryBundle } from "../../internal/utils.js"; import { PromptMixin } from "../../prompts/index.js"; import type { BaseSelector } from "../../selectors/index.js"; import { LLMSingleSelector } from "../../selectors/index.js"; @@ -44,7 +46,7 @@ async function combineResponses( } const summary = await summarizer.getResponse({ - query: queryBundle.queryStr, + query: extractText(queryBundle.query), textChunks: responseStrs, }); @@ -117,7 +119,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { ): Promise<EngineResponse | AsyncIterable<EngineResponse>> { const { query, stream } = params; - const response = await this.queryRoute({ queryStr: query }); + const response = await this.queryRoute(toQueryBundle(query)); if (stream) { throw new Error("Streaming is not supported yet."); @@ -142,7 +144,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { const selectedQueryEngine = this.queryEngines[engineInd.index]; responses.push( await selectedQueryEngine.query({ - query: queryBundle.queryStr, + query: extractText(queryBundle.query), }), ); } @@ -179,7 +181,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine { } const finalResponse = await selectedQueryEngine.query({ - query: queryBundle.queryStr, + query: extractText(queryBundle.query), }); // add selected result diff --git a/packages/llamaindex/src/internal/utils.ts b/packages/llamaindex/src/internal/utils.ts index a301c2707997ec3598ce3f9c7fcea7470a4d938d..c3395bea4e4b775f35dee5c71aab3e97aef96c05 100644 --- a/packages/llamaindex/src/internal/utils.ts +++ b/packages/llamaindex/src/internal/utils.ts @@ -3,6 +3,7 @@ import type { JSONValue } from "@llamaindex/core/global"; import type { ImageType } from "@llamaindex/core/schema"; import { fs } from "@llamaindex/env"; import { filetypemime } from "magic-bytes.js"; +import type { QueryBundle } from "../types.js"; export const isAsyncIterable = ( obj: unknown, @@ -202,3 +203,10 @@ export async function imageToDataUrl(input: ImageType): Promise<string> { } return await blobToDataUrl(input); } + +export function toQueryBundle(query: QueryBundle | string): QueryBundle { + if (typeof query === "string") { + return { query }; + } + return query; +} diff --git a/packages/llamaindex/src/prompts/Mixin.ts b/packages/llamaindex/src/prompts/Mixin.ts index b49d2db12dbf7da2ff59695745145de7bdc4d1cb..5d09421371994d07c13a841ba3142b029bdec9dd 100644 --- a/packages/llamaindex/src/prompts/Mixin.ts +++ b/packages/llamaindex/src/prompts/Mixin.ts @@ -75,6 +75,7 @@ export class PromptMixin { } // Must be implemented by subclasses + // fixme: says must but never implemented protected _getPrompts(): PromptsDict { return {}; } diff --git a/packages/llamaindex/src/selectors/base.ts b/packages/llamaindex/src/selectors/base.ts index a5ef61b41fcd0b1dc3855ee54bd768cf96190db0..22a5c66da95f52ab36b9b4a3499823aa3be90025 100644 --- a/packages/llamaindex/src/selectors/base.ts +++ b/packages/llamaindex/src/selectors/base.ts @@ -1,3 +1,4 @@ +import { toQueryBundle } from "../internal/utils.js"; import { PromptMixin } from "../prompts/Mixin.js"; import type { QueryBundle, ToolMetadataOnlyDescription } from "../types.js"; @@ -10,8 +11,6 @@ export type SelectorResult = { selections: SingleSelection[]; }; -type QueryType = string | QueryBundle; - function wrapChoice( choice: string | ToolMetadataOnlyDescription, ): ToolMetadataOnlyDescription { @@ -22,21 +21,13 @@ function wrapChoice( } } -function wrapQuery(query: QueryType): QueryBundle { - if (typeof query === "string") { - return { queryStr: query }; - } - - return query; -} - type MetadataType = string | ToolMetadataOnlyDescription; export abstract class BaseSelector extends PromptMixin { - async select(choices: MetadataType[], query: QueryType) { - const metadatas = choices.map((choice) => wrapChoice(choice)); - const queryBundle = wrapQuery(query); - return await this._select(metadatas, queryBundle); + async select(choices: MetadataType[], query: string | QueryBundle) { + const metadata = choices.map((choice) => wrapChoice(choice)); + const queryBundle = toQueryBundle(query); + return await this._select(metadata, queryBundle); } abstract _select( diff --git a/packages/llamaindex/src/selectors/llmSelectors.ts b/packages/llamaindex/src/selectors/llmSelectors.ts index 242b9973a25d95084ec500912cd2fa2c228f9913..e73966bfa11350e35d9c088a9d1b2fc1061c2723 100644 --- a/packages/llamaindex/src/selectors/llmSelectors.ts +++ b/packages/llamaindex/src/selectors/llmSelectors.ts @@ -1,4 +1,5 @@ import type { LLM } from "@llamaindex/core/llms"; +import { extractText } from "@llamaindex/core/utils"; import type { Answer } from "../outputParsers/selectors.js"; import { SelectionOutputParser } from "../outputParsers/selectors.js"; import type { @@ -88,7 +89,7 @@ export class LLMMultiSelector extends BaseSelector { const prompt = this.prompt( choicesText.length, choicesText, - query.queryStr, + extractText(query.query), this.maxOutputs, ); @@ -152,7 +153,11 @@ export class LLMSingleSelector extends BaseSelector { ): Promise<SelectorResult> { const choicesText = buildChoicesText(choices); - const prompt = this.prompt(choicesText.length, choicesText, query.queryStr); + const prompt = this.prompt( + choicesText.length, + choicesText, + extractText(query.query), + ); const formattedPrompt = this.outputParser.format(prompt); diff --git a/packages/llamaindex/src/synthesizers/builders.ts b/packages/llamaindex/src/synthesizers/builders.ts index 901b6728a3717c167e82403e5fb170f48961ac1a..f5ab12e814638a148ae95054d73e20c0406a1fe3 100644 --- a/packages/llamaindex/src/synthesizers/builders.ts +++ b/packages/llamaindex/src/synthesizers/builders.ts @@ -1,5 +1,6 @@ import type { LLM } from "@llamaindex/core/llms"; -import { streamConverter } from "@llamaindex/core/utils"; +import { extractText, streamConverter } from "@llamaindex/core/utils"; +import { toQueryBundle } from "../internal/utils.js"; import type { RefinePrompt, SimplePrompt, @@ -61,7 +62,7 @@ export class SimpleResponseBuilder implements ResponseBuilder { AsyncIterable<string> | string > { const input = { - query, + query: extractText(toQueryBundle(query).query), context: textChunks.join("\n\n"), }; @@ -142,14 +143,14 @@ export class Refine extends PromptMixin implements ResponseBuilder { const lastChunk = i === textChunks.length - 1; if (!response) { response = await this.giveResponseSingle( - query, + extractText(toQueryBundle(query).query), chunk, !!stream && lastChunk, ); } else { response = await this.refineResponseSingle( response as string, - query, + extractText(toQueryBundle(query).query), chunk, !!stream && lastChunk, ); @@ -254,9 +255,15 @@ export class CompactAndRefine extends Refine { AsyncIterable<string> | string > { const textQATemplate: SimplePrompt = (input) => - this.textQATemplate({ ...input, query: query }); + this.textQATemplate({ + ...input, + query: extractText(toQueryBundle(query).query), + }); const refineTemplate: SimplePrompt = (input) => - this.refineTemplate({ ...input, query: query }); + this.refineTemplate({ + ...input, + query: extractText(toQueryBundle(query).query), + }); const maxPrompt = getBiggestPrompt([textQATemplate, refineTemplate]); const newTexts = this.promptHelper.repack(maxPrompt, textChunks); @@ -335,7 +342,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { const params = { prompt: this.summaryTemplate({ context: packedTextChunks[0], - query, + query: extractText(toQueryBundle(query).query), }), }; if (stream) { @@ -349,7 +356,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder { this.llm.complete({ prompt: this.summaryTemplate({ context: chunk, - query, + query: extractText(toQueryBundle(query).query), }), }), ), diff --git a/packages/llamaindex/src/types.ts b/packages/llamaindex/src/types.ts index 19d697e3c7aa1a5a3778f5183f5a8524f7e306a4..66cf1c5f7aee54890f8847b209b67035143584a8 100644 --- a/packages/llamaindex/src/types.ts +++ b/packages/llamaindex/src/types.ts @@ -1,7 +1,7 @@ /** * Top level types to avoid circular dependencies */ -import type { ToolMetadata } from "@llamaindex/core/llms"; +import type { MessageContent, ToolMetadata } from "@llamaindex/core/llms"; import type { EngineResponse } from "./EngineResponse.js"; /** @@ -52,16 +52,15 @@ export interface StructuredOutput<T> { export type ToolMetadataOnlyDescription = Pick<ToolMetadata, "description">; -export class QueryBundle { - queryStr: string; - - constructor(queryStr: string) { - this.queryStr = queryStr; - } - - toString(): string { - return this.queryStr; - } -} +/** + * @link https://docs.llamaindex.ai/en/stable/api_reference/schema/?h=querybundle#llama_index.core.schema.QueryBundle + * + * We don't have `image_path` here, because it is included in the `query` field. + */ +export type QueryBundle = { + query: string | MessageContent; + customEmbedding?: string[]; + embeddings?: number[]; +}; export type UUID = `${string}-${string}-${string}-${string}-${string}`;