From 8e26f753b76a97344290e5cff21ff978cf4500ad Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Fri, 24 May 2024 23:08:20 +0800 Subject: [PATCH] feat: Add retrieval for images using multi-modal messages (#870) --- .changeset/hip-cycles-cheer.md | 5 ++ .../fixtures/embeddings/OpenAIEmbedding.ts | 9 +++- packages/core/src/Retriever.ts | 3 +- .../core/src/callbacks/CallbackManager.ts | 3 +- .../core/src/cloud/LlamaCloudRetriever.ts | 7 +-- packages/core/src/embeddings/ClipEmbedding.ts | 4 -- .../core/src/embeddings/GeminiEmbedding.ts | 4 -- .../src/embeddings/HuggingFaceEmbedding.ts | 4 -- .../core/src/embeddings/MistralAIEmbedding.ts | 4 -- .../src/embeddings/MultiModalEmbedding.ts | 16 ++++++ .../core/src/embeddings/OpenAIEmbedding.ts | 9 ---- packages/core/src/embeddings/types.ts | 13 ++++- .../engines/chat/DefaultContextGenerator.ts | 9 +++- packages/core/src/indices/keyword/index.ts | 3 +- packages/core/src/indices/summary/index.ts | 3 +- .../core/src/indices/vectorStore/index.ts | 54 +++++++++---------- packages/core/src/llm/ollama.ts | 4 -- packages/core/src/llm/types.ts | 4 +- packages/core/src/llm/utils.ts | 32 ++++++++++- .../postprocessors/rerankers/CohereRerank.ts | 6 ++- .../rerankers/JinaAIReranker.ts | 6 ++- packages/core/src/postprocessors/types.ts | 3 +- 22 files changed, 127 insertions(+), 78 deletions(-) create mode 100644 .changeset/hip-cycles-cheer.md diff --git a/.changeset/hip-cycles-cheer.md b/.changeset/hip-cycles-cheer.md new file mode 100644 index 000000000..513674171 --- /dev/null +++ b/.changeset/hip-cycles-cheer.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Add retrieval for images using multi-modal messages diff --git a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts index eec0bdbed..85c2963d1 100644 --- a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts +++ b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts @@ -1,9 
+1,14 @@ -import { BaseNode, SimilarityType, type BaseEmbedding } from "llamaindex"; +import { + BaseNode, + SimilarityType, + type BaseEmbedding, + type MessageContentDetail, +} from "llamaindex"; export class OpenAIEmbedding implements BaseEmbedding { embedBatchSize = 512; - async getQueryEmbedding(query: string) { + async getQueryEmbedding(query: MessageContentDetail) { return [0]; } diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index bc836527b..b37061d7d 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,8 +1,9 @@ import type { NodeWithScore } from "./Node.js"; import type { ServiceContext } from "./ServiceContext.js"; +import type { MessageContent } from "./index.edge.js"; export type RetrieveParams = { - query: string; + query: MessageContent; preFilters?: unknown; }; diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts index 3f8eae0d2..6e0940da5 100644 --- a/packages/core/src/callbacks/CallbackManager.ts +++ b/packages/core/src/callbacks/CallbackManager.ts @@ -12,6 +12,7 @@ import type { LLMStreamEvent, LLMToolCallEvent, LLMToolResultEvent, + MessageContent, RetrievalEndEvent, RetrievalStartEvent, } from "../llm/types.js"; @@ -99,7 +100,7 @@ export interface StreamCallbackResponse { } export interface RetrievalCallbackResponse { - query: string; + query: MessageContent; nodes: NodeWithScore[]; } diff --git a/packages/core/src/cloud/LlamaCloudRetriever.ts b/packages/core/src/cloud/LlamaCloudRetriever.ts index 6f3bb745d..7f01c57ea 100644 --- a/packages/core/src/cloud/LlamaCloudRetriever.ts +++ b/packages/core/src/cloud/LlamaCloudRetriever.ts @@ -2,8 +2,9 @@ import type { PlatformApi, PlatformApiClient } from "@llamaindex/cloud"; import type { NodeWithScore } from "../Node.js"; import { ObjectType, jsonToNode } from "../Node.js"; import type { BaseRetriever, RetrieveParams } from "../Retriever.js"; -import { Settings } from 
"../Settings.js"; import { wrapEventCaller } from "../internal/context/EventCaller.js"; +import { getCallbackManager } from "../internal/settings/CallbackManager.js"; +import { extractText } from "../llm/utils.js"; import type { ClientParams, CloudConstructorParams } from "./types.js"; import { DEFAULT_PROJECT_NAME } from "./types.js"; import { getClient } from "./utils.js"; @@ -70,13 +71,13 @@ export class LlamaCloudRetriever implements BaseRetriever { await this.getClient() ).pipeline.runSearch(pipelines[0].id, { ...this.retrieveParams, - query, + query: extractText(query), searchFilters: preFilters as Record<string, unknown[]>, }); const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes); - Settings.callbackManager.dispatchEvent("retrieve", { + getCallbackManager().dispatchEvent("retrieve", { query, nodes, }); diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts index f4f108291..dd8f3c529 100644 --- a/packages/core/src/embeddings/ClipEmbedding.ts +++ b/packages/core/src/embeddings/ClipEmbedding.ts @@ -87,8 +87,4 @@ export class ClipEmbedding extends MultiModalEmbedding { const { text_embeds } = await (await this.getTextModel())(textInputs); return text_embeds.data; } - - async getQueryEmbedding(query: string): Promise<number[]> { - return this.getTextEmbedding(query); - } } diff --git a/packages/core/src/embeddings/GeminiEmbedding.ts b/packages/core/src/embeddings/GeminiEmbedding.ts index fad773942..f08fe0619 100644 --- a/packages/core/src/embeddings/GeminiEmbedding.ts +++ b/packages/core/src/embeddings/GeminiEmbedding.ts @@ -36,8 +36,4 @@ export class GeminiEmbedding extends BaseEmbedding { getTextEmbedding(text: string): Promise<number[]> { return this.getEmbedding(text); } - - getQueryEmbedding(query: string): Promise<number[]> { - return this.getTextEmbedding(query); - } } diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts 
index d889aaeeb..dce8b0208 100644 --- a/packages/core/src/embeddings/HuggingFaceEmbedding.ts +++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts @@ -45,8 +45,4 @@ export class HuggingFaceEmbedding extends BaseEmbedding { const output = await extractor(text, { pooling: "mean", normalize: true }); return Array.from(output.data); } - - async getQueryEmbedding(query: string): Promise<number[]> { - return this.getTextEmbedding(query); - } } diff --git a/packages/core/src/embeddings/MistralAIEmbedding.ts b/packages/core/src/embeddings/MistralAIEmbedding.ts index fdef08085..e1fb3f5ca 100644 --- a/packages/core/src/embeddings/MistralAIEmbedding.ts +++ b/packages/core/src/embeddings/MistralAIEmbedding.ts @@ -30,8 +30,4 @@ export class MistralAIEmbedding extends BaseEmbedding { async getTextEmbedding(text: string): Promise<number[]> { return this.getMistralAIEmbedding(text); } - - async getQueryEmbedding(query: string): Promise<number[]> { - return this.getMistralAIEmbedding(query); - } } diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts index 0dbdd5390..edf2c4439 100644 --- a/packages/core/src/embeddings/MultiModalEmbedding.ts +++ b/packages/core/src/embeddings/MultiModalEmbedding.ts @@ -6,6 +6,8 @@ import { type BaseNode, type ImageType, } from "../Node.js"; +import type { MessageContentDetail } from "../llm/types.js"; +import { extractImage, extractSingleText } from "../llm/utils.js"; import { BaseEmbedding, batchEmbeddings } from "./types.js"; /* @@ -52,4 +54,18 @@ export abstract class MultiModalEmbedding extends BaseEmbedding { return nodes; } + + async getQueryEmbedding( + query: MessageContentDetail, + ): Promise<number[] | null> { + const image = extractImage(query); + if (image) { + return await this.getImageEmbedding(image); + } + const text = extractSingleText(query); + if (text) { + return await this.getTextEmbedding(text); + } + return null; + } } diff --git 
a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts index 42037692c..a3460616e 100644 --- a/packages/core/src/embeddings/OpenAIEmbedding.ts +++ b/packages/core/src/embeddings/OpenAIEmbedding.ts @@ -133,13 +133,4 @@ export class OpenAIEmbedding extends BaseEmbedding { async getTextEmbedding(text: string): Promise<number[]> { return (await this.getOpenAIEmbedding([text]))[0]; } - - /** - * Get embeddings for a query - * @param texts - * @param options - */ - async getQueryEmbedding(query: string): Promise<number[]> { - return (await this.getOpenAIEmbedding([query]))[0]; - } } diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts index 67d06940a..c0a237854 100644 --- a/packages/core/src/embeddings/types.ts +++ b/packages/core/src/embeddings/types.ts @@ -1,6 +1,8 @@ import type { BaseNode } from "../Node.js"; import { MetadataMode } from "../Node.js"; import type { TransformComponent } from "../ingestion/types.js"; +import type { MessageContentDetail } from "../llm/types.js"; +import { extractSingleText } from "../llm/utils.js"; import { SimilarityType, similarity } from "./utils.js"; const DEFAULT_EMBED_BATCH_SIZE = 10; @@ -19,7 +21,16 @@ export abstract class BaseEmbedding implements TransformComponent { } abstract getTextEmbedding(text: string): Promise<number[]>; - abstract getQueryEmbedding(query: string): Promise<number[]>; + + async getQueryEmbedding( + query: MessageContentDetail, + ): Promise<number[] | null> { + const text = extractSingleText(query); + if (text) { + return await this.getTextEmbedding(text); + } + return null; + } /** * Optionally override this method to retrieve multiple embeddings in a single request diff --git a/packages/core/src/engines/chat/DefaultContextGenerator.ts b/packages/core/src/engines/chat/DefaultContextGenerator.ts index 7e2b8e110..362a8335d 100644 --- a/packages/core/src/engines/chat/DefaultContextGenerator.ts +++ 
b/packages/core/src/engines/chat/DefaultContextGenerator.ts @@ -2,6 +2,7 @@ import type { NodeWithScore, TextNode } from "../../Node.js"; import type { ContextSystemPrompt } from "../../Prompt.js"; import { defaultContextSystemPrompt } from "../../Prompt.js"; import type { BaseRetriever } from "../../Retriever.js"; +import type { MessageContent } from "../../llm/types.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; import { PromptMixin } from "../../prompts/index.js"; import type { Context, ContextGenerator } from "./types.js"; @@ -41,7 +42,10 @@ export class DefaultContextGenerator } } - private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) { + private async applyNodePostprocessors( + nodes: NodeWithScore[], + query: MessageContent, + ) { let nodesWithScore = nodes; for (const postprocessor of this.nodePostprocessors) { @@ -54,7 +58,7 @@ export class DefaultContextGenerator return nodesWithScore; } - async generate(message: string): Promise<Context> { + async generate(message: MessageContent): Promise<Context> { const sourceNodesWithScore = await this.retriever.retrieve({ query: message, }); @@ -64,6 +68,7 @@ export class DefaultContextGenerator message, ); + // TODO: also use retrieved image nodes in context return { message: { content: this.contextSystemPrompt({ diff --git a/packages/core/src/indices/keyword/index.ts b/packages/core/src/indices/keyword/index.ts index 06d59fed6..7ff2c8255 100644 --- a/packages/core/src/indices/keyword/index.ts +++ b/packages/core/src/indices/keyword/index.ts @@ -29,6 +29,7 @@ import { import { llmFromSettingsOrContext } from "../../Settings.js"; import type { LLM } from "../../llm/types.js"; +import { extractText } from "../../llm/utils.js"; export interface KeywordIndexOptions { nodes?: BaseNode[]; @@ -85,7 +86,7 @@ abstract class BaseKeywordTableRetriever implements BaseRetriever { abstract getKeywords(query: string): Promise<string[]>; async retrieve({ query }: 
RetrieveParams): Promise<NodeWithScore[]> { - const keywords = await this.getKeywords(query); + const keywords = await this.getKeywords(extractText(query)); const chunkIndicesCount: { [key: string]: number } = {}; const filteredKeywords = keywords.filter((keyword) => this.indexStruct.table.has(keyword), diff --git a/packages/core/src/indices/summary/index.ts b/packages/core/src/indices/summary/index.ts index 3a29cd0ac..317ee00c0 100644 --- a/packages/core/src/indices/summary/index.ts +++ b/packages/core/src/indices/summary/index.ts @@ -11,6 +11,7 @@ import { } from "../../Settings.js"; import { RetrieverQueryEngine } from "../../engines/query/index.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; +import { extractText } from "../../llm/utils.js"; import type { BaseNodePostprocessor } from "../../postprocessors/index.js"; import type { StorageContext } from "../../storage/StorageContext.js"; import { storageContextFromDefaults } from "../../storage/StorageContext.js"; @@ -343,7 +344,7 @@ export class SummaryIndexLLMRetriever implements BaseRetriever { const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch); const fmtBatchStr = this.formatNodeBatchFn(nodesBatch); - const input = { context: fmtBatchStr, query: query }; + const input = { context: fmtBatchStr, query: extractText(query) }; const llm = llmFromSettingsOrContext(this.serviceContext); diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts index f185cac2f..a8b1e94e3 100644 --- a/packages/core/src/indices/vectorStore/index.ts +++ b/packages/core/src/indices/vectorStore/index.ts @@ -23,6 +23,7 @@ import { } from "../../ingestion/strategies/index.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; import { getCallbackManager } from "../../internal/settings/CallbackManager.js"; +import type { MessageContent } from "../../llm/types.js"; import type { BaseNodePostprocessor } from 
"../../postprocessors/types.js"; import type { StorageContext } from "../../storage/StorageContext.js"; import { storageContextFromDefaults } from "../../storage/StorageContext.js"; @@ -30,7 +31,6 @@ import type { MetadataFilters, VectorStore, VectorStoreByType, - VectorStoreQuery, VectorStoreQueryResult, } from "../../storage/index.js"; import type { BaseIndexStore } from "../../storage/indexStore/types.js"; @@ -422,10 +422,9 @@ export class VectorIndexRetriever implements BaseRetriever { let nodesWithScores: NodeWithScore[] = []; for (const type in vectorStores) { - // TODO: add retrieval by using an image as query const vectorStore: VectorStore = vectorStores[type as ModalityType]!; nodesWithScores = nodesWithScores.concat( - await this.textRetrieve( + await this.retrieveQuery( query, type as ModalityType, vectorStore, @@ -447,36 +446,33 @@ export class VectorIndexRetriever implements BaseRetriever { return nodesWithScores; } - protected async textRetrieve( - query: string, + protected async retrieveQuery( + query: MessageContent, type: ModalityType, vectorStore: VectorStore, preFilters?: MetadataFilters, ): Promise<NodeWithScore[]> { - const q = await this.buildVectorStoreQuery( - this.index.embedModel ?? vectorStore.embedModel, - query, - this.topK[type], - preFilters, - ); - const result = await vectorStore.query(q); - return this.buildNodeListFromQueryResult(result); - } - - protected async buildVectorStoreQuery( - embedModel: BaseEmbedding, - query: string, - similarityTopK: number, - preFilters?: MetadataFilters, - ): Promise<VectorStoreQuery> { - const queryEmbedding = await embedModel.getQueryEmbedding(query); - - return { - queryEmbedding, - mode: VectorStoreQueryMode.DEFAULT, - similarityTopK, - filters: preFilters ?? 
undefined, - }; + // convert string message to multi-modal format + if (typeof query === "string") { + query = [{ type: "text", text: query }]; + } + // overwrite embed model if specified, otherwise use the one from the vector store + const embedModel = this.index.embedModel ?? vectorStore.embedModel; + let nodes: NodeWithScore[] = []; + // query each content item (e.g. text or image) separately + for (const item of query) { + const queryEmbedding = await embedModel.getQueryEmbedding(item); + if (queryEmbedding) { + const result = await vectorStore.query({ + queryEmbedding, + mode: VectorStoreQueryMode.DEFAULT, + similarityTopK: this.topK[type], + filters: preFilters ?? undefined, + }); + nodes = nodes.concat(this.buildNodeListFromQueryResult(result)); + } + } + return nodes; } protected buildNodeListFromQueryResult(result: VectorStoreQueryResult) { diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts index 7343a2a4a..179cb7e58 100644 --- a/packages/core/src/llm/ollama.ts +++ b/packages/core/src/llm/ollama.ts @@ -191,10 +191,6 @@ export class Ollama return this.getEmbedding(text); } - async getQueryEmbedding(query: string): Promise<number[]> { - return this.getEmbedding(query); - } - // Inherited from OllamaBase push( diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index a2debd790..4060298c9 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -4,10 +4,10 @@ import type { BaseEvent } from "../internal/type.js"; import type { BaseTool, JSONObject, ToolOutput, UUID } from "../types.js"; export type RetrievalStartEvent = BaseEvent<{ - query: string; + query: MessageContent; }>; export type RetrievalEndEvent = BaseEvent<{ - query: string; + query: MessageContent; nodes: NodeWithScore[]; }>; export type LLMStartEvent = BaseEvent<{ diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts index 28240fff6..4a39e1584 100644 --- a/packages/core/src/llm/utils.ts +++ 
b/packages/core/src/llm/utils.ts @@ -1,4 +1,5 @@ import { AsyncLocalStorage, randomUUID } from "@llamaindex/env"; +import type { ImageType } from "../Node.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { ChatResponse, @@ -6,6 +7,7 @@ import type { LLM, LLMChat, MessageContent, + MessageContentDetail, MessageContentTextDetail, } from "./types.js"; @@ -62,7 +64,7 @@ export async function* streamReducer<S, D>(params: { export function extractText(message: MessageContent): string { if (typeof message !== "string" && !Array.isArray(message)) { console.warn( - "extractText called with non-string message, this is likely a bug.", + "extractText called with non-MessageContent message, this is likely a bug.", ); return `${message}`; } else if (typeof message !== "string" && Array.isArray(message)) { @@ -77,6 +79,34 @@ export function extractText(message: MessageContent): string { } } +/** + * Extracts the text from a single multi-modal message content item + * + * @param message The content item to extract the text from. + * @returns The extracted text, or null if the item is not a text item + */ +export function extractSingleText( + message: MessageContentDetail, +): string | null { + if (message.type === "text") { + return message.text; + } + return null; +} + +/** + * Extracts an image from a multi-modal message content + * + * @param message The content item to extract the image from.
+ * @returns The image URL, or null if the item is not an image + */ +export function extractImage(message: MessageContentDetail): ImageType | null { + if (message.type === "image_url") { + return new URL(message.image_url.url); + } + return null; +} + export const extractDataUrlComponents = ( dataUrl: string, ): { diff --git a/packages/core/src/postprocessors/rerankers/CohereRerank.ts b/packages/core/src/postprocessors/rerankers/CohereRerank.ts index 5cd04e19a..c57c48045 100644 --- a/packages/core/src/postprocessors/rerankers/CohereRerank.ts +++ b/packages/core/src/postprocessors/rerankers/CohereRerank.ts @@ -2,6 +2,8 @@ import { CohereClient } from "cohere-ai"; import type { NodeWithScore } from "../../Node.js"; import { MetadataMode } from "../../Node.js"; +import type { MessageContent } from "../../llm/types.js"; +import { extractText } from "../../llm/utils.js"; import type { BaseNodePostprocessor } from "../types.js"; type CohereRerankOptions = { @@ -46,7 +48,7 @@ export class CohereRerank implements BaseNodePostprocessor { */ async postprocessNodes( nodes: NodeWithScore[], - query?: string, + query?: MessageContent, ): Promise<NodeWithScore[]> { if (this.client === null) { throw new Error("CohereRerank client is null"); @@ -61,7 +63,7 @@ export class CohereRerank implements BaseNodePostprocessor { } const results = await this.client.rerank({ - query, + query: extractText(query), model: this.model, topN: this.topN, documents: nodes.map((n) => n.node.getContent(MetadataMode.ALL)), diff --git a/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts b/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts index 68bb8af00..1db612e5b 100644 --- a/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts +++ b/packages/core/src/postprocessors/rerankers/JinaAIReranker.ts @@ -1,6 +1,8 @@ import { getEnv } from "@llamaindex/env"; import type { NodeWithScore } from "../../Node.js"; import { MetadataMode } from "../../Node.js"; +import type { MessageContent } from
"../../llm/types.js"; +import { extractText } from "../../llm/utils.js"; import type { BaseNodePostprocessor } from "../types.js"; interface JinaAIRerankerResult { @@ -62,7 +64,7 @@ export class JinaAIReranker implements BaseNodePostprocessor { async postprocessNodes( nodes: NodeWithScore[], - query?: string, + query?: MessageContent, ): Promise<NodeWithScore[]> { if (nodes.length === 0) { return []; @@ -73,7 +75,7 @@ export class JinaAIReranker implements BaseNodePostprocessor { } const documents = nodes.map((n) => n.node.getContent(MetadataMode.ALL)); - const results = await this.rerank(query, documents, this.topN); + const results = await this.rerank(extractText(query), documents, this.topN); const newNodes: NodeWithScore[] = []; for (const result of results) { diff --git a/packages/core/src/postprocessors/types.ts b/packages/core/src/postprocessors/types.ts index a45ccabf7..bb04db34a 100644 --- a/packages/core/src/postprocessors/types.ts +++ b/packages/core/src/postprocessors/types.ts @@ -1,4 +1,5 @@ import type { NodeWithScore } from "../Node.js"; +import type { MessageContent } from "../llm/types.js"; export interface BaseNodePostprocessor { /** @@ -9,6 +10,6 @@ export interface BaseNodePostprocessor { */ postprocessNodes( nodes: NodeWithScore[], - query?: string, + query?: MessageContent, ): Promise<NodeWithScore[]>; } -- GitLab