From 92f07824a75cab09ca635e982a3fe73f74eecf3f Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Wed, 17 Jul 2024 20:17:06 -0700
Subject: [PATCH] feat: use query bundle (#702)

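Replaces the old `QueryBundle` class (with its `queryStr` field) with a plain
`QueryBundle` type whose `query` field accepts a string or `MessageContent`,
adds a `toQueryBundle` helper, and switches call sites to `extractText` where
a plain string is still required.

A minimal sketch of how the reshaped API fits together; the type and helper
below mirror this patch, while the two calls at the end are hypothetical usage
added for illustration, not code from this change:

  import type { MessageContent } from "@llamaindex/core/llms";

  // Mirrors packages/llamaindex/src/types.ts after this patch.
  type QueryBundle = {
    query: string | MessageContent;
    customEmbedding?: string[];
    embeddings?: number[];
  };

  // Mirrors toQueryBundle in packages/llamaindex/src/internal/utils.ts:
  // plain strings are wrapped, existing bundles pass through unchanged.
  function toQueryBundle(query: QueryBundle | string): QueryBundle {
    if (typeof query === "string") {
      return { query };
    }
    return query;
  }

  // Both call styles resolve to the same bundle shape.
  const a = toQueryBundle("What is a query bundle?");
  const b = toQueryBundle({ query: "What is a query bundle?" });
  console.log(a, b);
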
---
 .changeset/quiet-cows-rule.md                 |  5 ++++
 .../src/engines/query/RouterQueryEngine.ts    | 10 ++++----
 packages/llamaindex/src/internal/utils.ts     |  8 +++++++
 packages/llamaindex/src/prompts/Mixin.ts      |  1 +
 packages/llamaindex/src/selectors/base.ts     | 19 ++++-----------
 .../llamaindex/src/selectors/llmSelectors.ts  |  9 ++++++--
 .../llamaindex/src/synthesizers/builders.ts   | 23 ++++++++++++-------
 packages/llamaindex/src/types.ts              | 23 +++++++++----------
 8 files changed, 58 insertions(+), 40 deletions(-)
 create mode 100644 .changeset/quiet-cows-rule.md

diff --git a/.changeset/quiet-cows-rule.md b/.changeset/quiet-cows-rule.md
new file mode 100644
index 000000000..53cd41ea4
--- /dev/null
+++ b/.changeset/quiet-cows-rule.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: use query bundle
diff --git a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts
index 3239e1e84..408d70c4d 100644
--- a/packages/llamaindex/src/engines/query/RouterQueryEngine.ts
+++ b/packages/llamaindex/src/engines/query/RouterQueryEngine.ts
@@ -1,7 +1,9 @@
 import type { NodeWithScore } from "@llamaindex/core/schema";
+import { extractText } from "@llamaindex/core/utils";
 import { EngineResponse } from "../../EngineResponse.js";
 import type { ServiceContext } from "../../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../../Settings.js";
+import { toQueryBundle } from "../../internal/utils.js";
 import { PromptMixin } from "../../prompts/index.js";
 import type { BaseSelector } from "../../selectors/index.js";
 import { LLMSingleSelector } from "../../selectors/index.js";
@@ -44,7 +46,7 @@ async function combineResponses(
   }
 
   const summary = await summarizer.getResponse({
-    query: queryBundle.queryStr,
+    query: extractText(queryBundle.query),
     textChunks: responseStrs,
   });
 
@@ -117,7 +119,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine {
   ): Promise<EngineResponse | AsyncIterable<EngineResponse>> {
     const { query, stream } = params;
 
-    const response = await this.queryRoute({ queryStr: query });
+    const response = await this.queryRoute(toQueryBundle(query));
 
     if (stream) {
       throw new Error("Streaming is not supported yet.");
@@ -142,7 +144,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine {
         const selectedQueryEngine = this.queryEngines[engineInd.index];
         responses.push(
           await selectedQueryEngine.query({
-            query: queryBundle.queryStr,
+            query: extractText(queryBundle.query),
           }),
         );
       }
@@ -179,7 +181,7 @@ export class RouterQueryEngine extends PromptMixin implements QueryEngine {
       }
 
       const finalResponse = await selectedQueryEngine.query({
-        query: queryBundle.queryStr,
+        query: extractText(queryBundle.query),
       });
 
       // add selected result
diff --git a/packages/llamaindex/src/internal/utils.ts b/packages/llamaindex/src/internal/utils.ts
index a301c2707..c3395bea4 100644
--- a/packages/llamaindex/src/internal/utils.ts
+++ b/packages/llamaindex/src/internal/utils.ts
@@ -3,6 +3,7 @@ import type { JSONValue } from "@llamaindex/core/global";
 import type { ImageType } from "@llamaindex/core/schema";
 import { fs } from "@llamaindex/env";
 import { filetypemime } from "magic-bytes.js";
+import type { QueryBundle } from "../types.js";
 
 export const isAsyncIterable = (
   obj: unknown,
@@ -202,3 +203,10 @@ export async function imageToDataUrl(input: ImageType): Promise<string> {
   }
   return await blobToDataUrl(input);
 }
+
+export function toQueryBundle(query: QueryBundle | string): QueryBundle {
+  if (typeof query === "string") {
+    return { query };
+  }
+  return query;
+}
diff --git a/packages/llamaindex/src/prompts/Mixin.ts b/packages/llamaindex/src/prompts/Mixin.ts
index b49d2db12..5d0942137 100644
--- a/packages/llamaindex/src/prompts/Mixin.ts
+++ b/packages/llamaindex/src/prompts/Mixin.ts
@@ -75,6 +75,7 @@ export class PromptMixin {
   }
 
   // Must be implemented by subclasses
+  // fixme: says "must be implemented by subclasses", but this is never enforced
   protected _getPrompts(): PromptsDict {
     return {};
   }
diff --git a/packages/llamaindex/src/selectors/base.ts b/packages/llamaindex/src/selectors/base.ts
index a5ef61b41..22a5c66da 100644
--- a/packages/llamaindex/src/selectors/base.ts
+++ b/packages/llamaindex/src/selectors/base.ts
@@ -1,3 +1,4 @@
+import { toQueryBundle } from "../internal/utils.js";
 import { PromptMixin } from "../prompts/Mixin.js";
 import type { QueryBundle, ToolMetadataOnlyDescription } from "../types.js";
 
@@ -10,8 +11,6 @@ export type SelectorResult = {
   selections: SingleSelection[];
 };
 
-type QueryType = string | QueryBundle;
-
 function wrapChoice(
   choice: string | ToolMetadataOnlyDescription,
 ): ToolMetadataOnlyDescription {
@@ -22,21 +21,13 @@ function wrapChoice(
   }
 }
 
-function wrapQuery(query: QueryType): QueryBundle {
-  if (typeof query === "string") {
-    return { queryStr: query };
-  }
-
-  return query;
-}
-
 type MetadataType = string | ToolMetadataOnlyDescription;
 
 export abstract class BaseSelector extends PromptMixin {
-  async select(choices: MetadataType[], query: QueryType) {
-    const metadatas = choices.map((choice) => wrapChoice(choice));
-    const queryBundle = wrapQuery(query);
-    return await this._select(metadatas, queryBundle);
+  async select(choices: MetadataType[], query: string | QueryBundle) {
+    const metadata = choices.map((choice) => wrapChoice(choice));
+    const queryBundle = toQueryBundle(query);
+    return await this._select(metadata, queryBundle);
   }
 
   abstract _select(
diff --git a/packages/llamaindex/src/selectors/llmSelectors.ts b/packages/llamaindex/src/selectors/llmSelectors.ts
index 242b9973a..e73966bfa 100644
--- a/packages/llamaindex/src/selectors/llmSelectors.ts
+++ b/packages/llamaindex/src/selectors/llmSelectors.ts
@@ -1,4 +1,5 @@
 import type { LLM } from "@llamaindex/core/llms";
+import { extractText } from "@llamaindex/core/utils";
 import type { Answer } from "../outputParsers/selectors.js";
 import { SelectionOutputParser } from "../outputParsers/selectors.js";
 import type {
@@ -88,7 +89,7 @@ export class LLMMultiSelector extends BaseSelector {
     const prompt = this.prompt(
       choicesText.length,
       choicesText,
-      query.queryStr,
+      extractText(query.query),
       this.maxOutputs,
     );
 
@@ -152,7 +153,11 @@ export class LLMSingleSelector extends BaseSelector {
   ): Promise<SelectorResult> {
     const choicesText = buildChoicesText(choices);
 
-    const prompt = this.prompt(choicesText.length, choicesText, query.queryStr);
+    const prompt = this.prompt(
+      choicesText.length,
+      choicesText,
+      extractText(query.query),
+    );
 
     const formattedPrompt = this.outputParser.format(prompt);
 
diff --git a/packages/llamaindex/src/synthesizers/builders.ts b/packages/llamaindex/src/synthesizers/builders.ts
index 901b6728a..f5ab12e81 100644
--- a/packages/llamaindex/src/synthesizers/builders.ts
+++ b/packages/llamaindex/src/synthesizers/builders.ts
@@ -1,5 +1,6 @@
 import type { LLM } from "@llamaindex/core/llms";
-import { streamConverter } from "@llamaindex/core/utils";
+import { extractText, streamConverter } from "@llamaindex/core/utils";
+import { toQueryBundle } from "../internal/utils.js";
 import type {
   RefinePrompt,
   SimplePrompt,
@@ -61,7 +62,7 @@ export class SimpleResponseBuilder implements ResponseBuilder {
     AsyncIterable<string> | string
   > {
     const input = {
-      query,
+      query: extractText(toQueryBundle(query).query),
       context: textChunks.join("\n\n"),
     };
 
@@ -142,14 +143,14 @@ export class Refine extends PromptMixin implements ResponseBuilder {
       const lastChunk = i === textChunks.length - 1;
       if (!response) {
         response = await this.giveResponseSingle(
-          query,
+          extractText(toQueryBundle(query).query),
           chunk,
           !!stream && lastChunk,
         );
       } else {
         response = await this.refineResponseSingle(
           response as string,
-          query,
+          extractText(toQueryBundle(query).query),
           chunk,
           !!stream && lastChunk,
         );
@@ -254,9 +255,15 @@ export class CompactAndRefine extends Refine {
     AsyncIterable<string> | string
   > {
     const textQATemplate: SimplePrompt = (input) =>
-      this.textQATemplate({ ...input, query: query });
+      this.textQATemplate({
+        ...input,
+        query: extractText(toQueryBundle(query).query),
+      });
     const refineTemplate: SimplePrompt = (input) =>
-      this.refineTemplate({ ...input, query: query });
+      this.refineTemplate({
+        ...input,
+        query: extractText(toQueryBundle(query).query),
+      });
 
     const maxPrompt = getBiggestPrompt([textQATemplate, refineTemplate]);
     const newTexts = this.promptHelper.repack(maxPrompt, textChunks);
@@ -335,7 +342,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder {
       const params = {
         prompt: this.summaryTemplate({
           context: packedTextChunks[0],
-          query,
+          query: extractText(toQueryBundle(query).query),
         }),
       };
       if (stream) {
@@ -349,7 +356,7 @@ export class TreeSummarize extends PromptMixin implements ResponseBuilder {
           this.llm.complete({
             prompt: this.summaryTemplate({
               context: chunk,
-              query,
+              query: extractText(toQueryBundle(query).query),
             }),
           }),
         ),
diff --git a/packages/llamaindex/src/types.ts b/packages/llamaindex/src/types.ts
index 19d697e3c..66cf1c5f7 100644
--- a/packages/llamaindex/src/types.ts
+++ b/packages/llamaindex/src/types.ts
@@ -1,7 +1,7 @@
 /**
  * Top level types to avoid circular dependencies
  */
-import type { ToolMetadata } from "@llamaindex/core/llms";
+import type { MessageContent, ToolMetadata } from "@llamaindex/core/llms";
 import type { EngineResponse } from "./EngineResponse.js";
 
 /**
@@ -52,16 +52,15 @@ export interface StructuredOutput<T> {
 
 export type ToolMetadataOnlyDescription = Pick<ToolMetadata, "description">;
 
-export class QueryBundle {
-  queryStr: string;
-
-  constructor(queryStr: string) {
-    this.queryStr = queryStr;
-  }
-
-  toString(): string {
-    return this.queryStr;
-  }
-}
+/**
+ * @link https://docs.llamaindex.ai/en/stable/api_reference/schema/?h=querybundle#llama_index.core.schema.QueryBundle
+ *
+ * We don't have `image_path` here because it is included in the `query` field.
+ */
+export type QueryBundle = {
+  query: string | MessageContent;
+  customEmbedding?: string[];
+  embeddings?: number[];
+};
 
 export type UUID = `${string}-${string}-${string}-${string}-${string}`;
-- 
GitLab