From 8e26f753b76a97344290e5cff21ff978cf4500ad Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Fri, 24 May 2024 23:08:20 +0800
Subject: [PATCH] feat: Add retrieval for images using multi-modal messages
 (#870)

---
 .changeset/hip-cycles-cheer.md                |  5 ++
 .../fixtures/embeddings/OpenAIEmbedding.ts    |  9 +++-
 packages/core/src/Retriever.ts                |  3 +-
 .../core/src/callbacks/CallbackManager.ts     |  3 +-
 .../core/src/cloud/LlamaCloudRetriever.ts     |  7 +--
 packages/core/src/embeddings/ClipEmbedding.ts |  4 --
 .../core/src/embeddings/GeminiEmbedding.ts    |  4 --
 .../src/embeddings/HuggingFaceEmbedding.ts    |  4 --
 .../core/src/embeddings/MistralAIEmbedding.ts |  4 --
 .../src/embeddings/MultiModalEmbedding.ts     | 16 ++++++
 .../core/src/embeddings/OpenAIEmbedding.ts    |  9 ----
 packages/core/src/embeddings/types.ts         | 13 ++++-
 .../engines/chat/DefaultContextGenerator.ts   |  9 +++-
 packages/core/src/indices/keyword/index.ts    |  3 +-
 packages/core/src/indices/summary/index.ts    |  3 +-
 .../core/src/indices/vectorStore/index.ts     | 54 +++++++++----------
 packages/core/src/llm/ollama.ts               |  4 --
 packages/core/src/llm/types.ts                |  4 +-
 packages/core/src/llm/utils.ts                | 32 ++++++++++-
 .../postprocessors/rerankers/CohereRerank.ts  |  6 ++-
 .../rerankers/JinaAIReranker.ts               |  6 ++-
 packages/core/src/postprocessors/types.ts     |  3 +-
 22 files changed, 127 insertions(+), 78 deletions(-)
 create mode 100644 .changeset/hip-cycles-cheer.md

diff --git a/.changeset/hip-cycles-cheer.md b/.changeset/hip-cycles-cheer.md
new file mode 100644
index 000000000..513674171
--- /dev/null
+++ b/.changeset/hip-cycles-cheer.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Add retrieval for images using multi-modal messages
diff --git a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
index eec0bdbed..85c2963d1 100644
--- a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
+++ b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
@@ -1,9 +1,14 @@
-import { BaseNode, SimilarityType, type BaseEmbedding } from "llamaindex";
+import {
+  BaseNode,
+  SimilarityType,
+  type BaseEmbedding,
+  type MessageContentDetail,
+} from "llamaindex";
 
 export class OpenAIEmbedding implements BaseEmbedding {
   embedBatchSize = 512;
 
-  async getQueryEmbedding(query: string) {
+  async getQueryEmbedding(query: MessageContentDetail) {
     return [0];
   }
 
diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts
index bc836527b..b37061d7d 100644
--- a/packages/core/src/Retriever.ts
+++ b/packages/core/src/Retriever.ts
@@ -1,8 +1,9 @@
 import type { NodeWithScore } from "./Node.js";
 import type { ServiceContext } from "./ServiceContext.js";
+import type { MessageContent } from "./index.edge.js";
 
 export type RetrieveParams = {
-  query: string;
+  query: MessageContent;
   preFilters?: unknown;
 };
 
diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts
index 3f8eae0d2..6e0940da5 100644
--- a/packages/core/src/callbacks/CallbackManager.ts
+++ b/packages/core/src/callbacks/CallbackManager.ts
@@ -12,6 +12,7 @@ import type {
   LLMStreamEvent,
   LLMToolCallEvent,
   LLMToolResultEvent,
+  MessageContent,
   RetrievalEndEvent,
   RetrievalStartEvent,
 } from "../llm/types.js";
@@ -99,7 +100,7 @@ export interface StreamCallbackResponse {
 }
 
 export interface RetrievalCallbackResponse {
-  query: string;
+  query: MessageContent;
   nodes: NodeWithScore[];
 }
 
diff --git a/packages/core/src/cloud/LlamaCloudRetriever.ts b/packages/core/src/cloud/LlamaCloudRetriever.ts
index 6f3bb745d..7f01c57ea 100644
--- a/packages/core/src/cloud/LlamaCloudRetriever.ts
+++ b/packages/core/src/cloud/LlamaCloudRetriever.ts
@@ -2,8 +2,9 @@ import type { PlatformApi, PlatformApiClient } from "@llamaindex/cloud";
 import type { NodeWithScore } from "../Node.js";
 import { ObjectType, jsonToNode } from "../Node.js";
 import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
-import { Settings } from "../Settings.js";
 import { wrapEventCaller } from "../internal/context/EventCaller.js";
+import { getCallbackManager } from "../internal/settings/CallbackManager.js";
+import { extractText } from "../llm/utils.js";
 import type { ClientParams, CloudConstructorParams } from "./types.js";
 import { DEFAULT_PROJECT_NAME } from "./types.js";
 import { getClient } from "./utils.js";
@@ -70,13 +71,13 @@ export class LlamaCloudRetriever implements BaseRetriever {
       await this.getClient()
     ).pipeline.runSearch(pipelines[0].id, {
       ...this.retrieveParams,
-      query,
+      query: extractText(query),
       searchFilters: preFilters as Record<string, unknown[]>,
     });
 
     const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes);
 
-    Settings.callbackManager.dispatchEvent("retrieve", {
+    getCallbackManager().dispatchEvent("retrieve", {
       query,
       nodes,
     });
diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts
index f4f108291..dd8f3c529 100644
--- a/packages/core/src/embeddings/ClipEmbedding.ts
+++ b/packages/core/src/embeddings/ClipEmbedding.ts
@@ -87,8 +87,4 @@ export class ClipEmbedding extends MultiModalEmbedding {
     const { text_embeds } = await (await this.getTextModel())(textInputs);
     return text_embeds.data;
   }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getTextEmbedding(query);
-  }
 }
diff --git a/packages/core/src/embeddings/GeminiEmbedding.ts b/packages/core/src/embeddings/GeminiEmbedding.ts
index fad773942..f08fe0619 100644
--- a/packages/core/src/embeddings/GeminiEmbedding.ts
+++ b/packages/core/src/embeddings/GeminiEmbedding.ts
@@ -36,8 +36,4 @@ export class GeminiEmbedding extends BaseEmbedding {
   getTextEmbedding(text: string): Promise<number[]> {
     return this.getEmbedding(text);
   }
-
-  getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getTextEmbedding(query);
-  }
 }
diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
index d889aaeeb..dce8b0208 100644
--- a/packages/core/src/embeddings/HuggingFaceEmbedding.ts
+++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
@@ -45,8 +45,4 @@ export class HuggingFaceEmbedding extends BaseEmbedding {
     const output = await extractor(text, { pooling: "mean", normalize: true });
     return Array.from(output.data);
   }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getTextEmbedding(query);
-  }
 }
diff --git a/packages/core/src/embeddings/MistralAIEmbedding.ts b/packages/core/src/embeddings/MistralAIEmbedding.ts
index fdef08085..e1fb3f5ca 100644
--- a/packages/core/src/embeddings/MistralAIEmbedding.ts
+++ b/packages/core/src/embeddings/MistralAIEmbedding.ts
@@ -30,8 +30,4 @@ export class MistralAIEmbedding extends BaseEmbedding {
   async getTextEmbedding(text: string): Promise<number[]> {
     return this.getMistralAIEmbedding(text);
   }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getMistralAIEmbedding(query);
-  }
 }
diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts
index 0dbdd5390..edf2c4439 100644
--- a/packages/core/src/embeddings/MultiModalEmbedding.ts
+++ b/packages/core/src/embeddings/MultiModalEmbedding.ts
@@ -6,6 +6,8 @@ import {
   type BaseNode,
   type ImageType,
 } from "../Node.js";
+import type { MessageContentDetail } from "../llm/types.js";
+import { extractImage, extractSingleText } from "../llm/utils.js";
 import { BaseEmbedding, batchEmbeddings } from "./types.js";
 
 /*
@@ -52,4 +54,18 @@ export abstract class MultiModalEmbedding extends BaseEmbedding {
 
     return nodes;
   }
+
+  async getQueryEmbedding(
+    query: MessageContentDetail,
+  ): Promise<number[] | null> {
+    const image = extractImage(query);
+    if (image) {
+      return await this.getImageEmbedding(image);
+    }
+    const text = extractSingleText(query);
+    if (text) {
+      return await this.getTextEmbedding(text);
+    }
+    return null;
+  }
 }
diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts
index 42037692c..a3460616e 100644
--- a/packages/core/src/embeddings/OpenAIEmbedding.ts
+++ b/packages/core/src/embeddings/OpenAIEmbedding.ts
@@ -133,13 +133,4 @@ export class OpenAIEmbedding extends BaseEmbedding {
   async getTextEmbedding(text: string): Promise<number[]> {
     return (await this.getOpenAIEmbedding([text]))[0];
   }
-
-  /**
-   * Get embeddings for a query
-   * @param texts
-   * @param options
-   */
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return (await this.getOpenAIEmbedding([query]))[0];
-  }
 }
diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts
index 67d06940a..c0a237854 100644
--- a/packages/core/src/embeddings/types.ts
+++ b/packages/core/src/embeddings/types.ts
@@ -1,6 +1,8 @@
 import type { BaseNode } from "../Node.js";
 import { MetadataMode } from "../Node.js";
 import type { TransformComponent } from "../ingestion/types.js";
+import type { MessageContentDetail } from "../llm/types.js";
+import { extractSingleText } from "../llm/utils.js";
 import { SimilarityType, similarity } from "./utils.js";
 
 const DEFAULT_EMBED_BATCH_SIZE = 10;
@@ -19,7 +21,16 @@ export abstract class BaseEmbedding implements TransformComponent {
   }
 
   abstract getTextEmbedding(text: string): Promise<number[]>;
-  abstract getQueryEmbedding(query: string): Promise<number[]>;
+
+  async getQueryEmbedding(
+    query: MessageContentDetail,
+  ): Promise<number[] | null> {
+    const text = extractSingleText(query);
+    if (text) {
+      return await this.getTextEmbedding(text);
+    }
+    return null;
+  }
 
   /**
    * Optionally override this method to retrieve multiple embeddings in a single request
diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts
index 7e2b8e110..362a8335d 100644
--- a/packages/core/src/engines/chat/DefaultContextGenerator.ts
+++ b/packages/core/src/engines/chat/DefaultContextGenerator.ts
@@ -2,6 +2,7 @@ import type { NodeWithScore, TextNode } from "../../Node.js";
 import type { ContextSystemPrompt } from "../../Prompt.js";
 import { defaultContextSystemPrompt } from "../../Prompt.js";
 import type { BaseRetriever } from "../../Retriever.js";
+import type { MessageContent } from "../../llm/types.js";
 import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
 import { PromptMixin } from "../../prompts/index.js";
 import type { Context, ContextGenerator } from "./types.js";
@@ -41,7 +42,10 @@ export class DefaultContextGenerator
     }
   }
 
-  private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
+  private async applyNodePostprocessors(
+    nodes: NodeWithScore[],
+    query: MessageContent,
+  ) {
     let nodesWithScore = nodes;
 
     for (const postprocessor of this.nodePostprocessors) {
@@ -54,7 +58,7 @@ export class DefaultContextGenerator
     return nodesWithScore;
   }
 
-  async generate(message: string): Promise<Context> {
+  async generate(message: MessageContent): Promise<Context> {
     const sourceNodesWithScore = await this.retriever.retrieve({
       query: message,
     });
@@ -64,6 +68,7 @@ export class DefaultContextGenerator
       message,
     );
 
+    // TODO: also use retrieved image nodes in context
     return {
       message: {
         content: this.contextSystemPrompt({
diff --git a/packages/core/src/indices/keyword/index.ts b/packages/core/src/indices/keyword/index.ts
index 06d59fed6..7ff2c8255 100644
--- a/packages/core/src/indices/keyword/index.ts
+++ b/packages/core/src/indices/keyword/index.ts
@@ -29,6 +29,7 @@ import {
 
 import { llmFromSettingsOrContext } from "../../Settings.js";
 import type { LLM } from "../../llm/types.js";
+import { extractText } from "../../llm/utils.js";
 
 export interface KeywordIndexOptions {
   nodes?: BaseNode[];
@@ -85,7 +86,7 @@ abstract class BaseKeywordTableRetriever implements BaseRetriever {
   abstract getKeywords(query: string): Promise<string[]>;
 
   async retrieve({ query }: RetrieveParams): Promise<NodeWithScore[]> {
-    const keywords = await this.getKeywords(query);
+    const keywords = await this.getKeywords(extractText(query));
     const chunkIndicesCount: { [key: string]: number } = {};
     const filteredKeywords = keywords.filter((keyword) =>
       this.indexStruct.table.has(keyword),
diff --git a/packages/core/src/indices/summary/index.ts b/packages/core/src/indices/summary/index.ts
index 3a29cd0ac..317ee00c0 100644
--- a/packages/core/src/indices/summary/index.ts
+++ b/packages/core/src/indices/summary/index.ts
@@ -11,6 +11,7 @@ import {
 } from "../../Settings.js";
 import { RetrieverQueryEngine } from "../../engines/query/index.js";
 import { wrapEventCaller } from "../../internal/context/EventCaller.js";
+import { extractText } from "../../llm/utils.js";
 import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
 import type { StorageContext } from "../../storage/StorageContext.js";
 import { storageContextFromDefaults } from "../../storage/StorageContext.js";
@@ -343,7 +344,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever {
       const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);
 
       const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
-      const input = { context: fmtBatchStr, query: query };
+      const input = { context: fmtBatchStr, query: extractText(query) };
 
       const llm = llmFromSettingsOrContext(this.serviceContext);
 
diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts
index f185cac2f..a8b1e94e3 100644
--- a/packages/core/src/indices/vectorStore/index.ts
+++ b/packages/core/src/indices/vectorStore/index.ts
@@ -23,6 +23,7 @@ import {
 } from "../../ingestion/strategies/index.js";
 import { wrapEventCaller } from "../../internal/context/EventCaller.js";
 import { getCallbackManager } from "../../internal/settings/CallbackManager.js";
+import type { MessageContent } from "../../llm/types.js";
 import type { BaseNodePostprocessor } from "../../postprocessors/types.js";
 import type { StorageContext } from "../../storage/StorageContext.js";
 import { storageContextFromDefaults } from "../../storage/StorageContext.js";
@@ -30,7 +31,6 @@ import type {
   MetadataFilters,
   VectorStore,
   VectorStoreByType,
-  VectorStoreQuery,
   VectorStoreQueryResult,
 } from "../../storage/index.js";
 import type { BaseIndexStore } from "../../storage/indexStore/types.js";
@@ -422,10 +422,9 @@ export class VectorIndexRetriever implements BaseRetriever {
     let nodesWithScores: NodeWithScore[] = [];
 
     for (const type in vectorStores) {
-      // TODO: add retrieval by using an image as query
       const vectorStore: VectorStore = vectorStores[type as ModalityType]!;
       nodesWithScores = nodesWithScores.concat(
-        await this.textRetrieve(
+        await this.retrieveQuery(
           query,
           type as ModalityType,
           vectorStore,
@@ -447,36 +446,33 @@ export class VectorIndexRetriever implements BaseRetriever {
     return nodesWithScores;
   }
 
-  protected async textRetrieve(
-    query: string,
+  protected async retrieveQuery(
+    query: MessageContent,
     type: ModalityType,
     vectorStore: VectorStore,
     preFilters?: MetadataFilters,
   ): Promise<NodeWithScore[]> {
-    const q = await this.buildVectorStoreQuery(
-      this.index.embedModel ?? vectorStore.embedModel,
-      query,
-      this.topK[type],
-      preFilters,
-    );
-    const result = await vectorStore.query(q);
-    return this.buildNodeListFromQueryResult(result);
-  }
-
-  protected async buildVectorStoreQuery(
-    embedModel: BaseEmbedding,
-    query: string,
-    similarityTopK: number,
-    preFilters?: MetadataFilters,
-  ): Promise<VectorStoreQuery> {
-    const queryEmbedding = await embedModel.getQueryEmbedding(query);
-
-    return {
-      queryEmbedding,
-      mode: VectorStoreQueryMode.DEFAULT,
-      similarityTopK,
-      filters: preFilters ?? undefined,
-    };
+    // convert string message to multi-modal format
+    if (typeof query === "string") {
+      query = [{ type: "text", text: query }];
+    }
+    // overwrite embed model if specified, otherwise use the one from the vector store
+    const embedModel = this.index.embedModel ?? vectorStore.embedModel;
+    let nodes: NodeWithScore[] = [];
+    // query each content item (e.g. text or image) separately
+    for (const item of query) {
+      const queryEmbedding = await embedModel.getQueryEmbedding(item);
+      if (queryEmbedding) {
+        const result = await vectorStore.query({
+          queryEmbedding,
+          mode: VectorStoreQueryMode.DEFAULT,
+          similarityTopK: this.topK[type],
+          filters: preFilters ?? undefined,
+        });
+        nodes = nodes.concat(this.buildNodeListFromQueryResult(result));
+      }
+    }
+    return nodes;
   }
 
   protected buildNodeListFromQueryResult(result: VectorStoreQueryResult) {
diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts
index 7343a2a4a..179cb7e58 100644
--- a/packages/core/src/llm/ollama.ts
+++ b/packages/core/src/llm/ollama.ts
@@ -191,10 +191,6 @@ export class Ollama
     return this.getEmbedding(text);
   }
 
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getEmbedding(query);
-  }
-
   // Inherited from OllamaBase
 
   push(
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index a2debd790..4060298c9 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -4,10 +4,10 @@ import type { BaseEvent } from "../internal/type.js";
 import type { BaseTool, JSONObject, ToolOutput, UUID } from "../types.js";
 
 export type RetrievalStartEvent = BaseEvent<{
-  query: string;
+  query: MessageContent;
 }>;
 export type RetrievalEndEvent = BaseEvent<{
-  query: string;
+  query: MessageContent;
   nodes: NodeWithScore[];
 }>;
 export type LLMStartEvent = BaseEvent<{
diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts
index 28240fff6..4a39e1584 100644
--- a/packages/core/src/llm/utils.ts
+++ b/packages/core/src/llm/utils.ts
@@ -1,4 +1,5 @@
 import { AsyncLocalStorage, randomUUID } from "@llamaindex/env";
+import type { ImageType } from "../Node.js";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
 import type {
   ChatResponse,
@@ -6,6 +7,7 @@ import type {
   LLM,
   LLMChat,
   MessageContent,
+  MessageContentDetail,
   MessageContentTextDetail,
 } from "./types.js";
 
@@ -62,7 +64,7 @@ export async function* streamReducer<S, D>(params: {
 export function extractText(message: MessageContent): string {
   if (typeof message !== "string" && !Array.isArray(message)) {
     console.warn(
-      "extractText called with non-string message, this is likely a bug.",
+      "extractText called with non-MessageContent message, this is likely a bug.",
     );
     return `${message}`;
   } else if (typeof message !== "string" && Array.isArray(message)) {
@@ -77,6 +79,34 @@ export function extractText(message: MessageContent): string {
   }
 }
 
+/**
+ * Extracts a single text from a multi-modal message content
+ *
+ * @param message The message to extract text from.
+ * @returns The extracted text, or null if the content is not text
+ */
+export function extractSingleText(
+  message: MessageContentDetail,
+): string | null {
+  if (message.type === "text") {
+    return message.text;
+  }
+  return null;
+}
+
+/**
+ * Extracts an image from a multi-modal message content
+ *
+ * @param message The message to extract an image from.
+ * @returns The extracted image, or null if the content is not an image
+ */
+export function extractImage(message: MessageContentDetail): ImageType | null {
+  if (message.type === "image_url") {
+    return new URL(message.image_url.url);
+  }
+  return null;
+}
+
 export const extractDataUrlComponents = (
   dataUrl: string,
 ): {
diff --git a/packages/core/src/postprocessors/rerankers/CohereRerank.ts b/packages/core/src/postprocessors/rerankers/CohereRerank.ts
index 5cd04e19a..c57c48045 100644
--- a/packages/core/src/postprocessors/rerankers/CohereRerank.ts
+++ b/packages/core/src/postprocessors/rerankers/CohereRerank.ts
@@ -2,6 +2,8 @@ import { CohereClient } from "cohere-ai";
 
 import type { NodeWithScore } from "../../Node.js";
 import { MetadataMode } from "../../Node.js";
+import type { MessageContent } from "../../llm/types.js";
+import { extractText } from "../../llm/utils.js";
 import type { BaseNodePostprocessor } from "../types.js";
 
 type CohereRerankOptions = {
@@ -46,7 +48,7 @@ export class CohereRerank implements BaseNodePostprocessor {
    */
   async postprocessNodes(
     nodes: NodeWithScore[],
-    query?: string,
+    query?: MessageContent,
   ): Promise<NodeWithScore[]> {
     if (this.client === null) {
       throw new Error("CohereRerank client is null");
@@ -61,7 +63,7 @@ export class CohereRerank implements BaseNodePostprocessor {
     }
 
     const results = await this.client.rerank({
-      query,
+      query: extractText(query),
       model: this.model,
       topN: this.topN,
       documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)),
diff --git a/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts b/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts
index 68bb8af00..1db612e5b 100644
--- a/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts
+++ b/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts
@@ -1,6 +1,8 @@
 import { getEnv } from "@llamaindex/env";
 import type { NodeWithScore } from "../../Node.js";
 import { MetadataMode } from "../../Node.js";
+import type { MessageContent } from "../../llm/types.js";
+import { extractText } from "../../llm/utils.js";
 import type { BaseNodePostprocessor } from "../types.js";
 
 interface JinaAIRerankerResult {
@@ -62,7 +64,7 @@ export class JinaAIReranker implements BaseNodePostprocessor {
 
   async postprocessNodes(
     nodes: NodeWithScore[],
-    query?: string,
+    query?: MessageContent,
   ): Promise<NodeWithScore[]> {
     if (nodes.length === 0) {
       return [];
@@ -73,7 +75,7 @@ export class JinaAIReranker implements BaseNodePostprocessor {
     }
 
     const documents = nodes.map((n) => n.node.getContent(MetadataMode.ALL));
-    const results = await this.rerank(query, documents, this.topN);
+    const results = await this.rerank(extractText(query), documents, this.topN);
     const newNodes: NodeWithScore[] = [];
 
     for (const result of results) {
diff --git a/packages/core/src/postprocessors/types.ts b/packages/core/src/postprocessors/types.ts
index a45ccabf7..bb04db34a 100644
--- a/packages/core/src/postprocessors/types.ts
+++ b/packages/core/src/postprocessors/types.ts
@@ -1,4 +1,5 @@
 import type { NodeWithScore } from "../Node.js";
+import type { MessageContent } from "../llm/types.js";
 
 export interface BaseNodePostprocessor {
   /**
@@ -9,6 +10,6 @@ export interface BaseNodePostprocessor {
    */
   postprocessNodes(
     nodes: NodeWithScore[],
-    query?: string,
+    query?: MessageContent,
   ): Promise<NodeWithScore[]>;
 }
-- 
GitLab