Skip to content
Snippets Groups Projects
Commit 76831864 authored by Marcus Schiesser's avatar Marcus Schiesser
Browse files

feat: added MultiModelVectorStoreIndex

parent 22ff7da4
No related branches found
No related tags found
No related merge requests found
...@@ -229,13 +229,16 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> { ...@@ -229,13 +229,16 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> {
} }
} }
// export class ImageNode extends TextNode { export type ImageType = string | Blob | URL;
// image: string = "";
// getType(): ObjectType { export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
// return ObjectType.IMAGE; image?: ImageType; // base64 encoded image string
// } textEmbedding?: number[]; // Assuming text embedding is an array of numbers
// }
static getType(): string {
return ObjectType.IMAGE;
}
}
export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> { export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> {
indexId: string = ""; indexId: string = "";
......
import { ImageType } from "../Node";
import { MultiModalEmbedding } from "./MultiModalEmbedding"; import { MultiModalEmbedding } from "./MultiModalEmbedding";
import { ImageType, readImage } from "./utils"; import { readImage } from "./utils";
export enum ClipEmbeddingModelType { export enum ClipEmbeddingModelType {
XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32", XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
......
import { ImageType } from "../Node";
import { BaseEmbedding } from "./types"; import { BaseEmbedding } from "./types";
import { ImageType } from "./utils";
/* /*
* Base class for Multi Modal embeddings. * Base class for Multi Modal embeddings.
......
import _ from "lodash"; import _ from "lodash";
import { ImageType } from "../Node";
import { DEFAULT_SIMILARITY_TOP_K } from "../constants"; import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
import { VectorStoreQueryMode } from "../storage"; import { VectorStoreQueryMode } from "../storage";
import { SimilarityType } from "./types"; import { SimilarityType } from "./types";
...@@ -183,6 +184,7 @@ export function getTopKMMREmbeddings( ...@@ -183,6 +184,7 @@ export function getTopKMMREmbeddings(
return [resultSimilarities, resultIds]; return [resultSimilarities, resultIds];
} }
export async function readImage(input: ImageType) { export async function readImage(input: ImageType) {
const { RawImage } = await import("@xenova/transformers"); const { RawImage } = await import("@xenova/transformers");
if (input instanceof Blob) { if (input instanceof Blob) {
...@@ -193,4 +195,3 @@ export async function readImage(input: ImageType) { ...@@ -193,4 +195,3 @@ export async function readImage(input: ImageType) {
throw new Error(`Unsupported input type: ${typeof input}`); throw new Error(`Unsupported input type: ${typeof input}`);
} }
} }
export type ImageType = string | Blob | URL;
import _ from "lodash";
import { BaseNode, ImageNode, MetadataMode, TextNode } from "../../Node";
import { ClipEmbedding, MultiModalEmbedding } from "../../embeddings";
import { VectorStore } from "../../storage";
import { VectorStoreIndex } from "../vectorStore";
import { VectorIndexConstructorProps } from "../vectorStore/VectorStoreIndex";
export interface MultiModalVectorIndexConstructorProps
extends VectorIndexConstructorProps {
imageVectorStore: VectorStore;
imageEmbedModel?: MultiModalEmbedding;
}
export class MultiModalVectorStoreIndex extends VectorStoreIndex {
imageVectorStore: VectorStore;
imageEmbedModel: MultiModalEmbedding;
constructor(init: MultiModalVectorIndexConstructorProps) {
super(init);
this.imageVectorStore = init.imageVectorStore;
this.imageEmbedModel = init.imageEmbedModel ?? new ClipEmbedding();
}
/**
* Get the embeddings for image nodes.
* @param nodes
* @param serviceContext
* @param logProgress log progress to console (useful for debugging)
* @returns
*/
async getImageNodeEmbeddingResults(
nodes: ImageNode[],
logProgress: boolean = false,
) {
const isImageToText = nodes.every((node) => _.isString(node.text));
if (isImageToText) {
// image nodes have a text, use the text embedding model
return VectorStoreIndex.getNodeEmbeddingResults(
nodes,
this.serviceContext,
logProgress,
);
}
const nodesWithEmbeddings: ImageNode[] = [];
for (let i = 0; i < nodes.length; ++i) {
const node = nodes[i];
if (logProgress) {
console.log(`getting embedding for node ${i}/${nodes.length}`);
}
node.embedding = await this.imageEmbedModel.getImageEmbedding(
node.getContent(MetadataMode.EMBED),
);
nodesWithEmbeddings.push(node);
}
return nodesWithEmbeddings;
}
private splitNodes(nodes: BaseNode[]): {
imageNodes: ImageNode[];
textNodes: TextNode[];
} {
let imageNodes: ImageNode[] = [];
let textNodes: TextNode[] = [];
for (let node of nodes) {
if (node instanceof ImageNode) {
imageNodes.push(node);
}
if (node instanceof TextNode) {
textNodes.push(node);
}
}
return {
imageNodes,
textNodes,
};
}
async insertNodes(nodes: BaseNode[]): Promise<void> {
if (!nodes || nodes.length === 0) {
return;
}
const { imageNodes, textNodes } = this.splitNodes(nodes);
super.insertNodes(textNodes);
const imageNodesWithEmbedding =
await this.getImageNodeEmbeddingResults(imageNodes);
super.insertNodesToStore(this.imageVectorStore, imageNodesWithEmbedding);
}
}
...@@ -39,7 +39,7 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> { ...@@ -39,7 +39,7 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
export class VectorStoreIndex extends BaseIndex<IndexDict> { export class VectorStoreIndex extends BaseIndex<IndexDict> {
vectorStore: VectorStore; vectorStore: VectorStore;
private constructor(init: VectorIndexConstructorProps) { protected constructor(init: VectorIndexConstructorProps) {
super(init); super(init);
this.vectorStore = init.vectorStore; this.vectorStore = init.vectorStore;
} }
...@@ -259,15 +259,13 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -259,15 +259,13 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
); );
} }
async insertNodes(nodes: BaseNode[]): Promise<void> { async insertNodesToStore(
const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults( vectorStore: VectorStore,
nodes, nodes: BaseNode[],
this.serviceContext, ): Promise<void> {
); const newIds = await vectorStore.add(nodes);
const newIds = await this.vectorStore.add(embeddingResults);
if (!this.vectorStore.storesText) { if (!vectorStore.storesText) {
for (let i = 0; i < nodes.length; ++i) { for (let i = 0; i < nodes.length; ++i) {
this.indexStruct.addNode(nodes[i], newIds[i]); this.indexStruct.addNode(nodes[i], newIds[i]);
this.docStore.addDocuments([nodes[i]], true); this.docStore.addDocuments([nodes[i]], true);
...@@ -284,6 +282,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -284,6 +282,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
await this.storageContext.indexStore.addIndexStruct(this.indexStruct); await this.storageContext.indexStore.addIndexStruct(this.indexStruct);
} }
async insertNodes(nodes: BaseNode[]): Promise<void> {
const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults(
nodes,
this.serviceContext,
);
await this.insertNodesToStore(this.vectorStore, embeddingResults);
}
async deleteRefDoc( async deleteRefDoc(
refDocId: string, refDocId: string,
deleteFromDocStore: boolean = true, deleteFromDocStore: boolean = true,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment