diff --git a/examples/.gitignore b/examples/.gitignore index d8b83df9cdb661545c4f6d15a299e6e1f1f51dc4..8a77cb22f9809ef84223cfd2eac7b45d184a45ee 100644 --- a/examples/.gitignore +++ b/examples/.gitignore @@ -1 +1,2 @@ package-lock.json +storage diff --git a/examples/data/multi_modal/1.jpg b/examples/data/multi_modal/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..169024b540c591fa85e0d1c24c581dca6f8255b1 Binary files /dev/null and b/examples/data/multi_modal/1.jpg differ diff --git a/examples/data/multi_modal/2.jpg b/examples/data/multi_modal/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0a41cb1c510102bf3b8610275716bb7adac44581 Binary files /dev/null and b/examples/data/multi_modal/2.jpg differ diff --git a/examples/data/multi_modal/3.jpg b/examples/data/multi_modal/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c58d5a1fb48e4e8809804a37263bafa87c9494dc Binary files /dev/null and b/examples/data/multi_modal/3.jpg differ diff --git a/examples/data/multi_modal/60.jpg b/examples/data/multi_modal/60.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5cbdea93a792635df763dd778ec36ce9ae5cfc68 Binary files /dev/null and b/examples/data/multi_modal/60.jpg differ diff --git a/examples/data/multi_modal/61.jpg b/examples/data/multi_modal/61.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5c9edd53b61cc2b64d53cc96ab2844e34fe4831 Binary files /dev/null and b/examples/data/multi_modal/61.jpg differ diff --git a/examples/data/multi_modal/62.jpg b/examples/data/multi_modal/62.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bbb7b2199c7806ceb51ea382da77157ac7f86dc5 Binary files /dev/null and b/examples/data/multi_modal/62.jpg differ diff --git a/examples/data/multi_modal/San Francisco.txt b/examples/data/multi_modal/San Francisco.txt new file mode 100644 index 0000000000000000000000000000000000000000..938f45d2bc883b2f267b1dfed163db65e8e52f7c Binary files /dev/null and b/examples/data/multi_modal/San Francisco.txt differ diff --git a/examples/data/multi_modal/Vincent van Gogh.txt b/examples/data/multi_modal/Vincent van Gogh.txt new file mode 100644 index 0000000000000000000000000000000000000000..30b127be017a095afff17a95a5f1f93dfa917968 Binary files /dev/null and b/examples/data/multi_modal/Vincent van Gogh.txt differ diff --git a/examples/multiModal.ts b/examples/multiModal.ts new file mode 100644 index 0000000000000000000000000000000000000000..8a31980d550e30cd0c1ae90dca2ad7bda76d30b0 --- /dev/null +++ b/examples/multiModal.ts @@ -0,0 +1,48 @@ +import { + ImageNode, + serviceContextFromDefaults, + SimpleDirectoryReader, + SimpleVectorStore, + TextNode, + VectorStoreIndex, +} from "llamaindex"; +import * as path from "path"; + +async function main() { + // read data into documents + const reader = new SimpleDirectoryReader(); + const documents = await reader.loadData({ + directoryPath: "data/multi_modal", + }); + // set up vector store index with two vector stores, one for text, the other for images + const serviceContext = serviceContextFromDefaults({ chunkSize: 512 }); + const vectorStore = await SimpleVectorStore.fromPersistDir("./storage/text"); + const imageVectorStore = + await SimpleVectorStore.fromPersistDir("./storage/images"); + const index = await VectorStoreIndex.fromDocuments(documents, { + serviceContext, + imageVectorStore, + vectorStore, + }); + // retrieve documents using the index + const retriever = index.asRetriever(); + retriever.similarityTopK = 3; + const results = await retriever.retrieve( + "what are Vincent van Gogh's famous paintings", + ); + for (const result of results) { + const node = result.node; + if (!node) { + continue; + } + if (node instanceof ImageNode) { + console.log(`Image: ${path.join(__dirname, node.id_)}`); + } else if (node instanceof TextNode) { + console.log("Text:", (node as TextNode).text.substring(0, 128)); + } + console.log(`ID: ${node.id_}`); + console.log(`Similarity: ${result.score}`); + } +} + +main().catch(console.error); diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index d60f358e4bc364e52d9a15b8ee27adaac0535e12..67ed91a1c16220a732b7e2fee1a6514a8ab46b9f 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -14,6 +14,7 @@ export enum ObjectType { IMAGE = "IMAGE", INDEX = "INDEX", DOCUMENT = "DOCUMENT", + IMAGE_DOCUMENT = "IMAGE_DOCUMENT", } export enum MetadataMode { @@ -229,14 +230,6 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> { } } -// export class ImageNode extends TextNode { -// image: string = ""; - -// getType(): ObjectType { -// return ObjectType.IMAGE; -// } -// } - export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> { indexId: string = ""; @@ -285,14 +278,47 @@ export function jsonToNode(json: any, type?: ObjectType) { return new IndexNode(json); case ObjectType.DOCUMENT: return new Document(json); + case ObjectType.IMAGE_DOCUMENT: + return new ImageDocument(json); default: throw new Error(`Invalid node type: ${nodeType}`); } } -// export class ImageDocument extends Document { -// image?: string; -// } +export type ImageType = string | Blob | URL; + +export type ImageNodeConstructorProps<T extends Metadata> = Pick< + ImageNode<T>, + "image" | "id_" +> & + Partial<ImageNode<T>>; + +export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> { + image: ImageType; // image as blob + + constructor(init: ImageNodeConstructorProps<T>) { + super(init); + this.image = init.image; + } + + getType(): ObjectType { + return ObjectType.IMAGE; + } +} + +export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> { + constructor(init: ImageNodeConstructorProps<T>) { + super(init); + + if (new.target === ImageDocument) { + this.hash = this.generateHash(); + } + } + + getType() { + return ObjectType.IMAGE_DOCUMENT; + } +} /** * A node with a similarity score diff --git a/packages/core/src/NodeParser.ts b/packages/core/src/NodeParser.ts index f3d064ba5738702f4a11ba12483d6eb8122735d1..d39aae5ae98fd5ef25103850f990439e55ed94b2 100644 --- a/packages/core/src/NodeParser.ts +++ b/packages/core/src/NodeParser.ts @@ -1,4 +1,10 @@ -import { Document, NodeRelationship, TextNode } from "./Node"; +import { + BaseNode, + Document, + ImageDocument, + NodeRelationship, + TextNode, +} from "./Node"; import { SentenceSplitter } from "./TextSplitter"; import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "./constants"; @@ -27,12 +33,19 @@ export function getTextSplitsFromDocument( * @returns An array of nodes. */ export function getNodesFromDocument( - document: Document, + doc: BaseNode, textSplitter: SentenceSplitter, includeMetadata: boolean = true, includePrevNextRel: boolean = true, ) { - let nodes: TextNode[] = []; + if (doc instanceof ImageDocument) { + return [doc]; + } + if (!(doc instanceof Document)) { + throw new Error("Expected either an Image Document or Document"); + } + const document = doc as Document; + const nodes: TextNode[] = []; const textSplits = getTextSplitsFromDocument(document, textSplitter); @@ -62,7 +75,7 @@ export function getNodesFromDocument( } /** - * A NodeParser generates TextNodes from Documents + * A NodeParser generates Nodes from Documents */ export interface NodeParser { /** @@ -70,7 +83,7 @@ export interface NodeParser { * @param documents - The documents to generate nodes from. * @returns An array of nodes. */ - getNodesFromDocuments(documents: Document[]): TextNode[]; + getNodesFromDocuments(documents: BaseNode[]): BaseNode[]; } /** @@ -121,7 +134,7 @@ export class SimpleNodeParser implements NodeParser { * Generate Node objects from documents * @param documents */ - getNodesFromDocuments(documents: Document[]) { + getNodesFromDocuments(documents: BaseNode[]) { return documents .map((document) => getNodesFromDocument(document, this.textSplitter)) .flat(); diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts index b75b4b879c502b9f255b719985f3c0392e49f027..989914dc5ed2c957eed1cb461a24e7d675fc97bf 100644 --- a/packages/core/src/embeddings/ClipEmbedding.ts +++ b/packages/core/src/embeddings/ClipEmbedding.ts @@ -1,5 +1,6 @@ +import { ImageType } from "../Node"; import { MultiModalEmbedding } from "./MultiModalEmbedding"; -import { ImageType, readImage } from "./utils"; +import { readImage } from "./utils"; export enum ClipEmbeddingModelType { XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32", diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts index c86ba07211ba54b157b043252324270549889502..46d68ec25948db03c0137c35094acb72f6af557d 100644 --- a/packages/core/src/embeddings/MultiModalEmbedding.ts +++ b/packages/core/src/embeddings/MultiModalEmbedding.ts @@ -1,5 +1,5 @@ +import { ImageType } from "../Node"; import { BaseEmbedding } from "./types"; -import { ImageType } from "./utils"; /* * Base class for Multi Modal embeddings. @@ -9,7 +9,6 @@ export abstract class MultiModalEmbedding extends BaseEmbedding { abstract getImageEmbedding(images: ImageType): Promise<number[]>; async getImageEmbeddings(images: ImageType[]): Promise<number[][]> { - // Embed the input sequence of images asynchronously. return Promise.all( images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)), ); diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts index cd192c3d4faa08a5b4f81c6a796de1e89233ca45..cfdacf0871096efae8c231dcfb29b7a424fedeaf 100644 --- a/packages/core/src/embeddings/utils.ts +++ b/packages/core/src/embeddings/utils.ts @@ -1,4 +1,5 @@ import _ from "lodash"; +import { ImageType } from "../Node"; import { DEFAULT_SIMILARITY_TOP_K } from "../constants"; import { VectorStoreQueryMode } from "../storage"; import { SimilarityType } from "./types"; @@ -183,6 +184,7 @@ export function getTopKMMREmbeddings( return [resultSimilarities, resultIds]; } + export async function readImage(input: ImageType) { const { RawImage } = await import("@xenova/transformers"); if (input instanceof Blob) { @@ -193,4 +195,3 @@ export async function readImage(input: ImageType) { throw new Error(`Unsupported input type: ${typeof input}`); } } -export type ImageType = string | Blob | URL; diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts index 1cfa454e1482e1e8268e9fb1e0ca469e6dfcd551..b24b732ccf17d06b2ba926d4cecd8ab9f8543d05 100644 --- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts +++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts @@ -1,12 +1,14 @@ import { Event } from "../../callbacks/CallbackManager"; import { DEFAULT_SIMILARITY_TOP_K } from "../../constants"; +import { BaseEmbedding } from "../../embeddings"; import { globalsHelper } from "../../GlobalsHelper"; -import { NodeWithScore } from "../../Node"; +import { Metadata, NodeWithScore } from "../../Node"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext } from "../../ServiceContext"; import { VectorStoreQuery, VectorStoreQueryMode, + VectorStoreQueryResult, } from "../../storage/vectorStore/types"; import { VectorStoreIndex } from "./VectorStoreIndex"; @@ -37,16 +39,67 @@ export class VectorIndexRetriever implements BaseRetriever { parentEvent?: Event, preFilters?: unknown, ): Promise<NodeWithScore[]> { - const queryEmbedding = - await this.serviceContext.embedModel.getQueryEmbedding(query); + let nodesWithScores = await this.textRetrieve(query, preFilters); + nodesWithScores = nodesWithScores.concat( + await this.textToImageRetrieve(query, preFilters), + ); + this.sendEvent(query, nodesWithScores, parentEvent); + return nodesWithScores; + } + + protected async textRetrieve( + query: string, + preFilters?: unknown, + ): Promise<NodeWithScore[]> { + const q = await this.buildVectorStoreQuery(this.index.embedModel, query); + const result = await this.index.vectorStore.query(q, preFilters); + return this.buildNodeListFromQueryResult(result); + } + + private async textToImageRetrieve(query: string, preFilters?: unknown) { + if (!this.index.imageEmbedModel || !this.index.imageVectorStore) { + // no-op if image embedding and vector store are not set + return []; + } + const q = await this.buildVectorStoreQuery( + this.index.imageEmbedModel, + query, + ); + const result = await this.index.imageVectorStore.query(q, preFilters); + return this.buildNodeListFromQueryResult(result); + } - const q: VectorStoreQuery = { + protected sendEvent( + query: string, + nodesWithScores: NodeWithScore<Metadata>[], + parentEvent: Event | undefined, + ) { + if (this.serviceContext.callbackManager.onRetrieve) { + this.serviceContext.callbackManager.onRetrieve({ + query, + nodes: nodesWithScores, + event: globalsHelper.createEvent({ + parentEvent, + type: "retrieve", + }), + }); + } + } + + protected async buildVectorStoreQuery( + embedModel: BaseEmbedding, + query: string, + ): Promise<VectorStoreQuery> { + const queryEmbedding = await embedModel.getQueryEmbedding(query); + + return { queryEmbedding: queryEmbedding, mode: VectorStoreQueryMode.DEFAULT, similarityTopK: this.similarityTopK, }; - const result = await this.index.vectorStore.query(q, preFilters); + } + protected buildNodeListFromQueryResult(result: VectorStoreQueryResult) { let nodesWithScores: NodeWithScore[] = []; for (let i = 0; i < result.ids.length; i++) { const nodeFromResult = result.nodes?.[i]; @@ -61,17 +114,6 @@ export class VectorIndexRetriever implements BaseRetriever { }); } - if (this.serviceContext.callbackManager.onRetrieve) { - this.serviceContext.callbackManager.onRetrieve({ - query, - nodes: nodesWithScores, - event: globalsHelper.createEvent({ - parentEvent, - type: "retrieve", - }), - }); - } - return nodesWithScores; } diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index fe1c0d95d6ab759d14d3ad55379fd71635039e04..6721aaa862da17a9bc3019cec421b2f6c8218300 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -1,4 +1,12 @@ -import { BaseNode, Document, MetadataMode } from "../../Node"; +import { + BaseNode, + Document, + ImageNode, + MetadataMode, + ObjectType, + TextNode, + jsonToNode, +} from "../../Node"; import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine"; import { ResponseSynthesizer } from "../../ResponseSynthesizer"; import { BaseRetriever } from "../../Retriever"; @@ -6,11 +14,16 @@ import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; -import { BaseDocumentStore } from "../../storage/docStore/types"; +import { + BaseEmbedding, + ClipEmbedding, + MultiModalEmbedding, +} from "../../embeddings"; import { StorageContext, storageContextFromDefaults, } from "../../storage/StorageContext"; +import { BaseIndexStore } from "../../storage/indexStore/types"; import { VectorStore } from "../../storage/vectorStore/types"; import { BaseIndex, @@ -21,16 +34,21 @@ import { import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { VectorIndexRetriever } from "./VectorIndexRetriever"; -export interface VectorIndexOptions { - nodes?: BaseNode[]; +interface IndexStructOptions { indexStruct?: IndexDict; indexId?: string; +} +export interface VectorIndexOptions extends IndexStructOptions { + nodes?: BaseNode[]; serviceContext?: ServiceContext; storageContext?: StorageContext; + imageVectorStore?: VectorStore; + vectorStore?: VectorStore; } export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> { - vectorStore: VectorStore; + indexStore: BaseIndexStore; + imageVectorStore?: VectorStore; } /** @@ -38,15 +56,24 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> { */ export class VectorStoreIndex extends BaseIndex<IndexDict> { vectorStore: VectorStore; + indexStore: BaseIndexStore; + embedModel: BaseEmbedding; + imageVectorStore?: VectorStore; + imageEmbedModel?: MultiModalEmbedding; private constructor(init: VectorIndexConstructorProps) { super(init); - this.vectorStore = init.vectorStore; + this.indexStore = init.indexStore; + this.vectorStore = init.vectorStore ?? init.storageContext.vectorStore; + this.embedModel = init.serviceContext.embedModel; + this.imageVectorStore = init.imageVectorStore; + if (this.imageVectorStore) { + this.imageEmbedModel = new ClipEmbedding(); + } } /** - * The async init function should be called after the constructor. - * This is needed to handle persistence. + * The async init function creates a new VectorStoreIndex. * @param options * @returns */ @@ -55,11 +82,43 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { options.storageContext ?? (await storageContextFromDefaults({})); const serviceContext = options.serviceContext ?? serviceContextFromDefaults({}); - const docStore = storageContext.docStore; - const vectorStore = storageContext.vectorStore; const indexStore = storageContext.indexStore; + const docStore = storageContext.docStore; + + let indexStruct = await VectorStoreIndex.setupIndexStructFromStorage( + indexStore, + options, + ); + + if (!options.nodes && !indexStruct) { + throw new Error( + "Cannot initialize VectorStoreIndex without nodes or indexStruct", + ); + } + + indexStruct = indexStruct ?? new IndexDict(); + + const index = new this({ + storageContext, + serviceContext, + docStore, + indexStruct, + indexStore, + vectorStore: options.vectorStore, + imageVectorStore: options.imageVectorStore, + }); - // Setup IndexStruct from storage + if (options.nodes) { + // If nodes are passed in, then we need to update the index + await index.buildIndexFromNodes(options.nodes); + } + return index; + } + + private static async setupIndexStructFromStorage( + indexStore: BaseIndexStore, + options: IndexStructOptions, + ) { let indexStructs = (await indexStore.getIndexStructs()) as IndexDict[]; let indexStruct: IndexDict | undefined; @@ -77,55 +136,23 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { indexStruct = (await indexStore.getIndexStruct( options.indexId, )) as IndexDict; - } else { - indexStruct = undefined; } - - // check indexStruct type + // Check indexStruct type if (indexStruct && indexStruct.type !== IndexStructType.SIMPLE_DICT) { throw new Error( "Attempting to initialize VectorStoreIndex with non-vector indexStruct", ); } - - if (options.nodes) { - // If nodes are passed in, then we need to update the index - indexStruct = await VectorStoreIndex.buildIndexFromNodes( - options.nodes, - serviceContext, - vectorStore, - docStore, - indexStruct, - ); - - await indexStore.addIndexStruct(indexStruct); - } else if (!indexStruct) { - throw new Error( - "Cannot initialize VectorStoreIndex without nodes or indexStruct", - ); - } - - return new VectorStoreIndex({ - storageContext, - serviceContext, - docStore, - vectorStore, - indexStruct, - }); + return indexStruct; } /** * Get the embeddings for nodes. * @param nodes - * @param serviceContext * @param logProgress log progress to console (useful for debugging) * @returns */ - static async getNodeEmbeddingResults( - nodes: BaseNode[], - serviceContext: ServiceContext, - logProgress = false, - ) { + async getNodeEmbeddingResults(nodes: BaseNode[], logProgress = false) { const nodesWithEmbeddings: BaseNode[] = []; for (let i = 0; i < nodes.length; ++i) { @@ -133,7 +160,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { if (logProgress) { console.log(`getting embedding for node ${i}/${nodes.length}`); } - const embedding = await serviceContext.embedModel.getTextEmbedding( + const embedding = await this.embedModel.getTextEmbedding( node.getContent(MetadataMode.EMBED), ); node.embedding = embedding; @@ -146,77 +173,47 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { /** * Get embeddings for nodes and place them into the index. * @param nodes - * @param serviceContext - * @param vectorStore * @returns */ - static async buildIndexFromNodes( - nodes: BaseNode[], - serviceContext: ServiceContext, - vectorStore: VectorStore, - docStore: BaseDocumentStore, - indexDict?: IndexDict, - ): Promise<IndexDict> { - indexDict = indexDict ?? new IndexDict(); - + async buildIndexFromNodes(nodes: BaseNode[]) { // Check if the index already has nodes with the same hash const newNodes = nodes.filter((node) => - Object.entries(indexDict!.nodesDict).reduce((acc, [key, value]) => { - if (value.hash === node.hash) { - acc = false; - } - return acc; - }, true), + Object.entries(this.indexStruct!.nodesDict).reduce( + (acc, [key, value]) => { + if (value.hash === node.hash) { + acc = false; + } + return acc; + }, + true, + ), ); - const embeddingResults = await this.getNodeEmbeddingResults( - newNodes, - serviceContext, - ); - - await vectorStore.add(embeddingResults); - - if (!vectorStore.storesText) { - await docStore.addDocuments(embeddingResults, true); - } - - for (const node of embeddingResults) { - indexDict.addNode(node); - } - - return indexDict; + await this.insertNodes(newNodes); } /** * High level API: split documents, get embeddings, and build index. * @param documents - * @param storageContext - * @param serviceContext + * @param args * @returns */ static async fromDocuments( documents: Document[], - args: { - storageContext?: StorageContext; - serviceContext?: ServiceContext; - } = {}, + args: VectorIndexOptions = {}, ): Promise<VectorStoreIndex> { - let { storageContext, serviceContext } = args; - storageContext = storageContext ?? (await storageContextFromDefaults({})); - serviceContext = serviceContext ?? serviceContextFromDefaults({}); - const docStore = storageContext.docStore; + args.storageContext = + args.storageContext ?? (await storageContextFromDefaults({})); + args.serviceContext = args.serviceContext ?? serviceContextFromDefaults({}); + const docStore = args.storageContext.docStore; for (const doc of documents) { docStore.setDocumentHash(doc.id_, doc.hash); } - const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents); - const index = await VectorStoreIndex.init({ - nodes, - storageContext, - serviceContext, - }); - return index; + args.nodes = + args.serviceContext.nodeParser.getNodesFromDocuments(documents); + return await this.init(args); } static async fromVectorStore( @@ -231,7 +228,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { const storageContext = await storageContextFromDefaults({ vectorStore }); - const index = await VectorStoreIndex.init({ + const index = await this.init({ nodes: [], storageContext, serviceContext, @@ -259,51 +256,131 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ); } - async insertNodes(nodes: BaseNode[]): Promise<void> { - const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults( - nodes, - this.serviceContext, - ); - - const newIds = await this.vectorStore.add(embeddingResults); + protected async insertNodesToStore( + vectorStore: VectorStore, + nodes: BaseNode[], + ): Promise<void> { + const newIds = await vectorStore.add(nodes); - if (!this.vectorStore.storesText) { - for (let i = 0; i < nodes.length; ++i) { - this.indexStruct.addNode(nodes[i], newIds[i]); - this.docStore.addDocuments([nodes[i]], true); - } - } else { - for (let i = 0; i < nodes.length; ++i) { - if (nodes[i].getType() === "INDEX") { - this.indexStruct.addNode(nodes[i], newIds[i]); - this.docStore.addDocuments([nodes[i]], true); - } + // NOTE: if the vector store doesn't store text, + // we need to add the nodes to the index struct and document store + // NOTE: if the vector store keeps text, + // we only need to add image and index nodes + for (let i = 0; i < nodes.length; ++i) { + const type = nodes[i].getType(); + if ( + !vectorStore.storesText || + type === ObjectType.INDEX || + type === ObjectType.IMAGE + ) { + const nodeWithoutEmbedding = jsonToNode(nodes[i].toJSON()); + nodeWithoutEmbedding.embedding = undefined; + this.indexStruct.addNode(nodeWithoutEmbedding, newIds[i]); + this.docStore.addDocuments([nodeWithoutEmbedding], true); } } + } - await this.storageContext.indexStore.addIndexStruct(this.indexStruct); + async insertNodes(nodes: BaseNode[]): Promise<void> { + if (!nodes || nodes.length === 0) { + return; + } + const { imageNodes, textNodes } = this.splitNodes(nodes); + if (imageNodes.length > 0) { + if (!this.imageVectorStore) { + throw new Error("Cannot insert image nodes without image vector store"); + } + const imageNodesWithEmbedding = + await this.getImageNodeEmbeddingResults(imageNodes); + await this.insertNodesToStore( + this.imageVectorStore, + imageNodesWithEmbedding, + ); + } + const embeddingResults = await this.getNodeEmbeddingResults(textNodes); + await this.insertNodesToStore(this.vectorStore, embeddingResults); + await this.indexStore.addIndexStruct(this.indexStruct); } async deleteRefDoc( refDocId: string, deleteFromDocStore: boolean = true, ): Promise<void> { - this.vectorStore.delete(refDocId); + await this.deleteRefDocFromStore(this.vectorStore, refDocId); + if (this.imageVectorStore) { + await this.deleteRefDocFromStore(this.imageVectorStore, refDocId); + } + + if (deleteFromDocStore) { + await this.docStore.deleteDocument(refDocId, false); + } + } + + protected async deleteRefDocFromStore( + vectorStore: VectorStore, + refDocId: string, + ): Promise<void> { + vectorStore.delete(refDocId); - if (!this.vectorStore.storesText) { + if (!vectorStore.storesText) { const refDocInfo = await this.docStore.getRefDocInfo(refDocId); if (refDocInfo) { for (const nodeId of refDocInfo.nodeIds) { this.indexStruct.delete(nodeId); + vectorStore.delete(nodeId); } } + await this.indexStore.addIndexStruct(this.indexStruct); + } + } + + /** + * Get the embeddings for image nodes. + * @param nodes + * @param serviceContext + * @param logProgress log progress to console (useful for debugging) + * @returns + */ + async getImageNodeEmbeddingResults( + nodes: ImageNode[], + logProgress: boolean = false, + ): Promise<BaseNode[]> { + if (!this.imageEmbedModel) { + return []; + } - await this.storageContext.indexStore.addIndexStruct(this.indexStruct); + const nodesWithEmbeddings: ImageNode[] = []; + + for (let i = 0; i < nodes.length; ++i) { + const node = nodes[i]; + if (logProgress) { + console.log(`getting embedding for node ${i}/${nodes.length}`); + } + node.embedding = await this.imageEmbedModel.getImageEmbedding(node.image); + nodesWithEmbeddings.push(node); } - if (deleteFromDocStore) { - await this.docStore.deleteDocument(refDocId, false); + return nodesWithEmbeddings; + } + + private splitNodes(nodes: BaseNode[]): { + imageNodes: ImageNode[]; + textNodes: TextNode[]; + } { + let imageNodes: ImageNode[] = []; + let textNodes: TextNode[] = []; + + for (let node of nodes) { + if (node instanceof ImageNode) { + imageNodes.push(node); + } else if (node instanceof TextNode) { + textNodes.push(node); + } } + return { + imageNodes, + textNodes, + }; } } diff --git a/packages/core/src/readers/ImageReader.ts b/packages/core/src/readers/ImageReader.ts new file mode 100644 index 0000000000000000000000000000000000000000..fd1b3969558b7076a3ceaad960eb07f9fe86f42e --- /dev/null +++ b/packages/core/src/readers/ImageReader.ts @@ -0,0 +1,25 @@ +import { Document, ImageDocument } from "../Node"; +import { DEFAULT_FS } from "../storage/constants"; +import { GenericFileSystem } from "../storage/FileSystem"; +import { BaseReader } from "./base"; + +/** + * Reads the content of an image file into a Document object (which stores the image file as a Blob). + */ +export class ImageReader implements BaseReader { + /** + * Public method for this reader. + * Required by BaseReader interface. + * @param file Path/name of the file to be loaded. + * @param fs fs wrapper interface for getting the file content. + * @returns Promise<Document[]> A Promise object, eventually yielding zero or one ImageDocument of the specified file. + */ + async loadData( + file: string, + fs: GenericFileSystem = DEFAULT_FS, + ): Promise<Document[]> { + const dataBuffer = await fs.readFile(file); + const blob = new Blob([dataBuffer]); + return [new ImageDocument({ image: blob, id_: file })]; + } +} diff --git a/packages/core/src/readers/SimpleDirectoryReader.ts b/packages/core/src/readers/SimpleDirectoryReader.ts index 9a2e30aa1698b8f81524bfd9849be529ca0370d6..5146256ae0251266429b34e58032a8068a8bcaed 100644 --- a/packages/core/src/readers/SimpleDirectoryReader.ts +++ b/packages/core/src/readers/SimpleDirectoryReader.ts @@ -5,6 +5,7 @@ import { DEFAULT_FS } from "../storage/constants"; import { PapaCSVReader } from "./CSVReader"; import { DocxReader } from "./DocxReader"; import { HTMLReader } from "./HTMLReader"; +import { ImageReader } from "./ImageReader"; import { MarkdownReader } from "./MarkdownReader"; import { PDFReader } from "./PDFReader"; import { BaseReader } from "./base"; @@ -42,6 +43,10 @@ export const FILE_EXT_TO_READER: Record<string, BaseReader> = { docx: new DocxReader(), htm: new HTMLReader(), html: new HTMLReader(), + jpg: new ImageReader(), + jpeg: new ImageReader(), + png: new ImageReader(), + gif: new ImageReader(), }; export type SimpleDirectoryReaderLoadDataProps = { @@ -54,7 +59,7 @@ export type SimpleDirectoryReaderLoadDataProps = { /** * Read all of the documents in a directory. * By default, supports the list of file types - * in the FILE_EXIT_TO_READER map. + * in the FILE_EXT_TO_READER map. */ export class SimpleDirectoryReader implements BaseReader { constructor(private observer?: ReaderCallback) {} diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts index 929ebe2c24b3cf9b7cc074902280d89785651b0d..e5242c0ffa913a831d5fd90d0ca42a626765647b 100644 --- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts @@ -164,7 +164,7 @@ export class SimpleVectorStore implements VectorStore { let dirPath = path.dirname(persistPath); if (!(await exists(fs, dirPath))) { - await fs.mkdir(dirPath); + await fs.mkdir(dirPath, { recursive: true }); } let dataDict: any = {};