diff --git a/.changeset/dry-queens-dream.md b/.changeset/dry-queens-dream.md new file mode 100644 index 0000000000000000000000000000000000000000..6ef8c01b3c46022d2bda974ac9d2efa8f0f4cb93 --- /dev/null +++ b/.changeset/dry-queens-dream.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Add vectorStores to storage context to define vector store per modality diff --git a/examples/multimodal/load.ts b/examples/multimodal/load.ts index 3ed94e30b246c561d8d51b13073b038f1ba8be9a..e2d0fa7a44172f141cf456f6cf4fa827531856e2 100644 --- a/examples/multimodal/load.ts +++ b/examples/multimodal/load.ts @@ -1,11 +1,6 @@ -import { - Settings, - SimpleDirectoryReader, - VectorStoreIndex, - storageContextFromDefaults, -} from "llamaindex"; - -import * as path from "path"; +import { Settings, SimpleDirectoryReader, VectorStoreIndex } from "llamaindex"; +import path from "path"; +import { getStorageContext } from "./storage"; // Update chunk size and overlap Settings.chunkSize = 512; @@ -25,10 +20,7 @@ async function generateDatasource() { const documents = await new SimpleDirectoryReader().loadData({ directoryPath: path.join("multimodal", "data"), }); - const storageContext = await storageContextFromDefaults({ - persistDir: "storage", - storeImages: true, - }); + const storageContext = await getStorageContext(); await VectorStoreIndex.fromDocuments(documents, { storageContext, }); diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts index 0213de3cc5f1cbdc072d5531c3d209dbfd72d350..402ed5575991deb831835303c4d176863be390cd 100644 --- a/examples/multimodal/rag.ts +++ b/examples/multimodal/rag.ts @@ -1,12 +1,12 @@ import { - CallbackManager, ImageType, MultiModalResponseSynthesizer, OpenAI, + RetrievalEndEvent, Settings, VectorStoreIndex, - storageContextFromDefaults, } from "llamaindex"; +import { getStorageContext } from "./storage"; // Update chunk size and overlap Settings.chunkSize = 512; @@ -16,32 +16,23 @@ Settings.chunkOverlap = 20; Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 }); // Update callbackManager -Settings.callbackManager = new CallbackManager({ - onRetrieve: ({ query, nodes }) => { - console.log(`Retrieved ${nodes.length} nodes for query: ${query}`); - }, +Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => { + const { nodes, query } = event.detail.payload; + console.log(`Retrieved ${nodes.length} nodes for query: ${query}`); }); -export async function createIndex() { - // set up vector store index with two vector stores, one for text, the other for images - const storageContext = await storageContextFromDefaults({ - persistDir: "storage", - storeImages: true, - }); - return await VectorStoreIndex.init({ - nodes: [], - storageContext, - }); -} - async function main() { const images: ImageType[] = []; - const index = await createIndex(); + const storageContext = await getStorageContext(); + const index = await VectorStoreIndex.init({ + nodes: [], + storageContext, + }); const queryEngine = index.asQueryEngine({ responseSynthesizer: new MultiModalResponseSynthesizer(), - retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }), + retriever: index.asRetriever({ topK: { TEXT: 3, IMAGE: 1 } }), }); const result = await queryEngine.query({ query: "Tell me more about Vincent van Gogh's famous paintings", diff --git a/examples/multimodal/retrieve.ts b/examples/multimodal/retrieve.ts index 7c5bf2f85d5d3ae4a5f4617d92253e872805eb80..beffdb22ef8603f87b2ebf1162feaed6823261e0 100644 --- a/examples/multimodal/retrieve.ts +++ b/examples/multimodal/retrieve.ts @@ -1,31 +1,18 @@ -import { - ImageNode, - Settings, - TextNode, - VectorStoreIndex, - storageContextFromDefaults, -} from "llamaindex"; +import { ImageNode, Settings, TextNode, VectorStoreIndex } from "llamaindex"; +import { getStorageContext } from "./storage"; // Update chunk size and overlap Settings.chunkSize = 512; Settings.chunkOverlap = 20; -export async function createIndex() { - // set up vector store index with two vector stores, one for text, the other for images - const storageContext = await storageContextFromDefaults({ - persistDir: "storage", - storeImages: true, - }); - return await VectorStoreIndex.init({ +async function main() { + // retrieve documents using the index + const storageContext = await getStorageContext(); + const index = await VectorStoreIndex.init({ nodes: [], storageContext, }); -} - -async function main() { - // retrieve documents using the index - const index = await createIndex(); - const retriever = index.asRetriever({ similarityTopK: 3 }); + const retriever = index.asRetriever({ topK: { TEXT: 1, IMAGE: 3 } }); const results = await retriever.retrieve({ query: "what are Vincent van Gogh's famous paintings", }); @@ -40,7 +27,7 @@ async function main() { console.log("Text:", (node as TextNode).text.substring(0, 128)); } console.log(`ID: ${node.id_}`); - console.log(`Similarity: ${result.score}`); + console.log(`Similarity: ${result.score}\n`); } } diff --git a/examples/multimodal/storage.ts b/examples/multimodal/storage.ts new file mode 100644 index 0000000000000000000000000000000000000000..7a0ba621a2c1ea587fd3edd0e370be3bc7da8486 --- /dev/null +++ b/examples/multimodal/storage.ts @@ -0,0 +1,17 @@ +import { storageContextFromDefaults } from "llamaindex"; + +// set up store context with two vector stores, one for text, the other for images +export async function getStorageContext() { + return await storageContextFromDefaults({ + persistDir: "storage", + storeImages: true, + // if storeImages is true, the following vector store will be added + // vectorStores: { + // IMAGE: SimpleVectorStore.fromPersistDir( + // `${persistDir}/images`, + // fs, + // new ClipEmbedding(), + // ), + // }, + }); +} diff --git a/examples/vectorIndexFromVectorStore.ts b/examples/vectorIndexFromVectorStore.ts index dde1edddf35e9124aef620b2edce1d8ec7cb4e13..cf6c672a90fece3c5e0b70c98447661c5706faf5 100644 --- a/examples/vectorIndexFromVectorStore.ts +++ b/examples/vectorIndexFromVectorStore.ts @@ -1,5 +1,6 @@ import { OpenAI, + OpenAIEmbedding, ResponseSynthesizer, RetrieverQueryEngine, Settings, @@ -28,6 +29,7 @@ class PineconeVectorStore<T extends RecordMetadata = RecordMetadata> { storesText = true; isEmbeddingQuery = false; + embedModel = new OpenAIEmbedding(); indexName!: string; pineconeClient!: Pinecone; diff --git a/packages/core/package.json b/packages/core/package.json index 19da579b16e37f9de45c8b030811bdf689e186c2..97568cd3ba06840861eb214c25a76e504bc67e9f 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -151,7 +151,7 @@ "build:type": "tsc -p tsconfig.json", "copy": "cp -r ../../README.md ../../LICENSE .", "postbuild": "pnpm run copy && node -e \"require('fs').writeFileSync('./dist/cjs/package.json', JSON.stringify({ type: 'commonjs' }))\"", - "circular-check": "madge -c ./src/index.ts", + "circular-check": "madge -c ./src/**/*.ts", "dev": "concurrently \"pnpm run build:esm --watch\" \"pnpm run build:cjs --watch\" \"pnpm run build:type --watch\"" } } diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index 820f464df7a8f8d08bca851b9a19dc98242d1825..389fcdd71a471420d00fce3b4f470aff104eccf2 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -417,22 +417,32 @@ export interface NodeWithScore<T extends Metadata = Metadata> { score?: number; } -export function splitNodesByType(nodes: BaseNode[]): { - imageNodes: ImageNode[]; - textNodes: TextNode[]; -} { - const imageNodes: ImageNode[] = []; - const textNodes: TextNode[] = []; +export enum ModalityType { + TEXT = "TEXT", + IMAGE = "IMAGE", +} + +type NodesByType = { + [P in ModalityType]?: BaseNode[]; +}; + +export function splitNodesByType(nodes: BaseNode[]): NodesByType { + const result: NodesByType = {}; for (const node of nodes) { + let type: ModalityType; if (node instanceof ImageNode) { - imageNodes.push(node); + type = ModalityType.IMAGE; } else if (node instanceof TextNode) { - textNodes.push(node); + type = ModalityType.TEXT; + } else { + throw new Error(`Unknown node type: ${node.type}`); + } + if (type in result) { + result[type]?.push(node); + } else { + result[type] = [node]; } } - return { - imageNodes, - textNodes, - }; + return result; } diff --git a/packages/core/src/Settings.ts b/packages/core/src/Settings.ts index 7b2d751ccfe2621ce289dc6f4cccc35c07550723..16481a211c667428fd4ec7276c618b5e70224c9d 100644 --- a/packages/core/src/Settings.ts +++ b/packages/core/src/Settings.ts @@ -1,5 +1,4 @@ import { CallbackManager } from "./callbacks/CallbackManager.js"; -import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js"; import { OpenAI } from "./llm/openai.js"; import { PromptHelper } from "./PromptHelper.js"; @@ -13,6 +12,11 @@ import { setCallbackManager, withCallbackManager, } from "./internal/settings/CallbackManager.js"; +import { + getEmbeddedModel, + setEmbeddedModel, + withEmbeddedModel, +} from "./internal/settings/EmbedModel.js"; import { getChunkSize, setChunkSize, @@ -44,13 +48,11 @@ class GlobalSettings implements Config { #prompt: PromptConfig = {}; #llm: LLM | null = null; #promptHelper: PromptHelper | null = null; - #embedModel: BaseEmbedding | null = null; #nodeParser: NodeParser | null = null; #chunkOverlap?: number; #llmAsyncLocalStorage = new AsyncLocalStorage<LLM>(); #promptHelperAsyncLocalStorage = new AsyncLocalStorage<PromptHelper>(); - #embedModelAsyncLocalStorage = new AsyncLocalStorage<BaseEmbedding>(); #nodeParserAsyncLocalStorage = new AsyncLocalStorage<NodeParser>(); #chunkOverlapAsyncLocalStorage = new AsyncLocalStorage<number>(); #promptAsyncLocalStorage = new AsyncLocalStorage<PromptConfig>(); @@ -100,19 +102,15 @@ class GlobalSettings implements Config { } get embedModel(): BaseEmbedding { - if (this.#embedModel === null) { - this.#embedModel = new OpenAIEmbedding(); - } - - return this.#embedModelAsyncLocalStorage.getStore() ?? this.#embedModel; + return getEmbeddedModel(); } set embedModel(embedModel: BaseEmbedding) { - this.#embedModel = embedModel; + setEmbeddedModel(embedModel); } withEmbedModel<Result>(embedModel: BaseEmbedding, fn: () => Result): Result { - return this.#embedModelAsyncLocalStorage.run(embedModel, fn); + return withEmbeddedModel(embedModel, fn); } get nodeParser(): NodeParser { diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts index e220eede0ac7c1ba0a148cc0dfbcf078c4338bad..0dbdd5390535011534b02b65db69ac3cc165b9cc 100644 --- a/packages/core/src/embeddings/MultiModalEmbedding.ts +++ b/packages/core/src/embeddings/MultiModalEmbedding.ts @@ -1,5 +1,7 @@ import { + ImageNode, MetadataMode, + ModalityType, splitNodesByType, type BaseNode, type ImageType, @@ -24,7 +26,9 @@ export abstract class MultiModalEmbedding extends BaseEmbedding { } async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { - const { imageNodes, textNodes } = splitNodesByType(nodes); + const nodeMap = splitNodesByType(nodes); + const imageNodes = nodeMap[ModalityType.IMAGE] ?? []; + const textNodes = nodeMap[ModalityType.TEXT] ?? []; const embeddings = await batchEmbeddings( textNodes.map((node) => node.getContent(MetadataMode.EMBED)), @@ -37,7 +41,7 @@ export abstract class MultiModalEmbedding extends BaseEmbedding { } const imageEmbeddings = await batchEmbeddings( - imageNodes.map((n) => n.image), + imageNodes.map((n) => (n as ImageNode).image), this.getImageEmbeddings.bind(this), this.embedBatchSize, _options, diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts index 80fa61e822348d471d1563c39d510f254a4f5773..d62677df3ae29249a3ccf91144d95fe5cfa24583 100644 --- a/packages/core/src/embeddings/utils.ts +++ b/packages/core/src/embeddings/utils.ts @@ -3,7 +3,7 @@ import _ from "lodash"; import { filetypemime } from "magic-bytes.js"; import type { ImageType } from "../Node.js"; import { DEFAULT_SIMILARITY_TOP_K } from "../constants.js"; -import { VectorStoreQueryMode } from "../storage/vectorStore/types.js"; +import type { VectorStoreQueryMode } from "../storage/vectorStore/types.js"; /** * Similarity type @@ -126,11 +126,9 @@ export function getTopKEmbeddingsLearner( embeddings: number[][], similarityTopK?: number, embeddingsIds?: any[], - queryMode: VectorStoreQueryMode = VectorStoreQueryMode.SVM, + queryMode?: VectorStoreQueryMode, ): [number[], any[]] { throw new Error("Not implemented yet"); - // To support SVM properly we're probably going to have to use something like - // https://github.com/mljs/libsvm which itself hasn't been updated in a while } // eslint-disable-next-line max-params diff --git a/packages/core/src/indices/BaseIndex.ts b/packages/core/src/indices/BaseIndex.ts index 69f9a79b90e180569bd368af66a0cec602047047..cd199a7d73898ac58c1d1fcc9d798e6d49f4a843 100644 --- a/packages/core/src/indices/BaseIndex.ts +++ b/packages/core/src/indices/BaseIndex.ts @@ -6,7 +6,6 @@ import { runTransformations } from "../ingestion/IngestionPipeline.js"; import type { StorageContext } from "../storage/StorageContext.js"; import type { BaseDocumentStore } from "../storage/docStore/types.js"; import type { BaseIndexStore } from "../storage/indexStore/types.js"; -import type { VectorStore } from "../storage/vectorStore/types.js"; import type { BaseSynthesizer } from "../synthesizers/types.js"; import type { QueryEngine } from "../types.js"; import { IndexStruct } from "./IndexStruct.js"; @@ -47,7 +46,6 @@ export interface BaseIndexInit<T> { serviceContext?: ServiceContext; storageContext: StorageContext; docStore: BaseDocumentStore; - vectorStore?: VectorStore; indexStore?: BaseIndexStore; indexStruct: T; } @@ -60,7 +58,6 @@ export abstract class BaseIndex<T> { serviceContext?: ServiceContext; storageContext: StorageContext; docStore: BaseDocumentStore; - vectorStore?: VectorStore; indexStore?: BaseIndexStore; indexStruct: T; @@ -68,7 +65,6 @@ export abstract class BaseIndex<T> { this.serviceContext = init.serviceContext; this.storageContext = init.storageContext; this.docStore = init.docStore; - this.vectorStore = init.vectorStore; this.indexStore = init.indexStore; this.indexStruct = init.indexStruct; } diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts index d96cb3a05fd571cefec81f182f23a89b1d8a9dec..f185cac2f5f1de0e5afac1b588a5f90b04e38327 100644 --- a/packages/core/src/indices/vectorStore/index.ts +++ b/packages/core/src/indices/vectorStore/index.ts @@ -1,19 +1,22 @@ -import type { BaseNode, Document, NodeWithScore } from "../../Node.js"; -import { ImageNode, ObjectType, splitNodesByType } from "../../Node.js"; +import { + ImageNode, + ModalityType, + ObjectType, + splitNodesByType, + type BaseNode, + type Document, + type NodeWithScore, +} from "../../Node.js"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; -import { - embedModelFromSettingsOrContext, - nodeParserFromSettingsOrContext, -} from "../../Settings.js"; +import { nodeParserFromSettingsOrContext } from "../../Settings.js"; import { DEFAULT_SIMILARITY_TOP_K } from "../../constants.js"; -import { ClipEmbedding } from "../../embeddings/ClipEmbedding.js"; -import type { - BaseEmbedding, - MultiModalEmbedding, -} from "../../embeddings/index.js"; +import type { BaseEmbedding } from "../../embeddings/index.js"; import { RetrieverQueryEngine } from "../../engines/query/RetrieverQueryEngine.js"; -import { runTransformations } from "../../ingestion/IngestionPipeline.js"; +import { + addNodesToVectorStores, + runTransformations, +} from "../../ingestion/IngestionPipeline.js"; import { DocStoreStrategy, createDocStoreStrategy, @@ -26,6 +29,7 @@ import { storageContextFromDefaults } from "../../storage/StorageContext.js"; import type { MetadataFilters, VectorStore, + VectorStoreByType, VectorStoreQuery, VectorStoreQueryResult, } from "../../storage/index.js"; @@ -45,36 +49,28 @@ export interface VectorIndexOptions extends IndexStructOptions { nodes?: BaseNode[]; serviceContext?: ServiceContext; storageContext?: StorageContext; - imageVectorStore?: VectorStore; - vectorStore?: VectorStore; + vectorStores?: VectorStoreByType; logProgress?: boolean; } export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> { indexStore: BaseIndexStore; - imageVectorStore?: VectorStore; + vectorStores?: VectorStoreByType; } /** - * The VectorStoreIndex, an index that stores the nodes only according to their vector embedings. + * The VectorStoreIndex, an index that stores the nodes only according to their vector embeddings. */ export class VectorStoreIndex extends BaseIndex<IndexDict> { - vectorStore: VectorStore; indexStore: BaseIndexStore; - embedModel: BaseEmbedding; - imageVectorStore?: VectorStore; - imageEmbedModel?: MultiModalEmbedding; + embedModel?: BaseEmbedding; + vectorStores: VectorStoreByType; private constructor(init: VectorIndexConstructorProps) { super(init); this.indexStore = init.indexStore; - this.vectorStore = init.vectorStore ?? init.storageContext.vectorStore; - this.embedModel = embedModelFromSettingsOrContext(init.serviceContext); - this.imageVectorStore = - init.imageVectorStore ?? init.storageContext.imageVectorStore; - if (this.imageVectorStore) { - this.imageEmbedModel = new ClipEmbedding(); - } + this.vectorStores = init.vectorStores ?? init.storageContext.vectorStores; + this.embedModel = init.serviceContext?.embedModel; } /** @@ -110,8 +106,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { docStore, indexStruct, indexStore, - vectorStore: options.vectorStore, - imageVectorStore: options.imageVectorStore, + vectorStores: options.vectorStores, }); if (options.nodes) { @@ -169,20 +164,17 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { nodes: BaseNode[], options?: { logProgress?: boolean }, ): Promise<BaseNode[]> { - const { imageNodes, textNodes } = splitNodesByType(nodes); - if (imageNodes.length > 0) { - if (!this.imageEmbedModel) { - throw new Error( - "Cannot calculate image nodes embedding without 'imageEmbedModel' set", - ); + const nodeMap = splitNodesByType(nodes); + for (const type in nodeMap) { + const nodes = nodeMap[type as ModalityType]; + const embedModel = + this.embedModel ?? this.vectorStores[type as ModalityType]?.embedModel; + if (embedModel && nodes) { + await embedModel.transform(nodes, { + logProgress: options?.logProgress, + }); } - await this.imageEmbedModel.transform(imageNodes, { - logProgress: options?.logProgress, - }); } - await this.embedModel.transform(textNodes, { - logProgress: options?.logProgress, - }); return nodes; } @@ -210,14 +202,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { docStoreStrategy?: DocStoreStrategy; } = {}, ): Promise<VectorStoreIndex> { + args.storageContext = + args.storageContext ?? (await storageContextFromDefaults({})); + args.vectorStores = args.vectorStores ?? args.storageContext.vectorStores; args.docStoreStrategy = args.docStoreStrategy ?? // set doc store strategy defaults to the same as for the IngestionPipeline - (args.vectorStore + (args.vectorStores ? DocStoreStrategy.UPSERTS : DocStoreStrategy.DUPLICATES_ONLY); - args.storageContext = - args.storageContext ?? (await storageContextFromDefaults({})); args.serviceContext = args.serviceContext; const docStore = args.storageContext.docStore; @@ -226,10 +219,11 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { } // use doc store strategy to avoid duplicates + const vectorStores = Object.values(args.vectorStores ?? {}); const docStoreStrategy = createDocStoreStrategy( args.docStoreStrategy, docStore, - args.vectorStore, + vectorStores, ); args.nodes = await runTransformations( documents, @@ -243,20 +237,18 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { return await this.init(args); } - static async fromVectorStore( - vectorStore: VectorStore, + static async fromVectorStores( + vectorStores: VectorStoreByType, serviceContext?: ServiceContext, - imageVectorStore?: VectorStore, ) { - if (!vectorStore.storesText) { + if (!vectorStores[ModalityType.TEXT]?.storesText) { throw new Error( "Cannot initialize from a vector store that does not store text", ); } const storageContext = await storageContextFromDefaults({ - vectorStore, - imageVectorStore, + vectorStores, }); const index = await this.init({ @@ -268,6 +260,16 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { return index; } + static async fromVectorStore( + vectorStore: VectorStore, + serviceContext?: ServiceContext, + ) { + return this.fromVectorStores( + { [ModalityType.TEXT]: vectorStore }, + serviceContext, + ); + } + asRetriever( options?: Omit<VectorIndexRetrieverOptions, "index">, ): VectorIndexRetriever { @@ -301,11 +303,10 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { } protected async insertNodesToStore( - vectorStore: VectorStore, + newIds: string[], nodes: BaseNode[], + vectorStore: VectorStore, ): Promise<void> { - const newIds = await vectorStore.add(nodes); - // NOTE: if the vector store doesn't store text, // we need to add the nodes to the index struct and document store // NOTE: if the vector store keeps text, @@ -333,14 +334,11 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { return; } nodes = await this.getNodeEmbeddingResults(nodes, options); - const { imageNodes, textNodes } = splitNodesByType(nodes); - if (imageNodes.length > 0) { - if (!this.imageVectorStore) { - throw new Error("Cannot insert image nodes without image vector store"); - } - await this.insertNodesToStore(this.imageVectorStore, imageNodes); - } - await this.insertNodesToStore(this.vectorStore, textNodes); + await addNodesToVectorStores( + nodes, + this.vectorStores, + this.insertNodesToStore.bind(this), + ); await this.indexStore.addIndexStruct(this.indexStruct); } @@ -348,11 +346,9 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { refDocId: string, deleteFromDocStore: boolean = true, ): Promise<void> { - await this.deleteRefDocFromStore(this.vectorStore, refDocId); - if (this.imageVectorStore) { - await this.deleteRefDocFromStore(this.imageVectorStore, refDocId); + for (const vectorStore of Object.values(this.vectorStores)) { + await this.deleteRefDocFromStore(vectorStore, refDocId); } - if (deleteFromDocStore) { await this.docStore.deleteDocument(refDocId, false); } @@ -382,28 +378,34 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { * VectorIndexRetriever retrieves nodes from a VectorIndex. */ +type TopKMap = { [P in ModalityType]: number }; + export type VectorIndexRetrieverOptions = { index: VectorStoreIndex; similarityTopK?: number; - imageSimilarityTopK?: number; + topK?: TopKMap; }; export class VectorIndexRetriever implements BaseRetriever { index: VectorStoreIndex; - similarityTopK: number; - imageSimilarityTopK: number; + topK: TopKMap; serviceContext?: ServiceContext; - constructor({ - index, - similarityTopK, - imageSimilarityTopK, - }: VectorIndexRetrieverOptions) { + constructor({ index, similarityTopK, topK }: VectorIndexRetrieverOptions) { this.index = index; this.serviceContext = this.index.serviceContext; - this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K; - this.imageSimilarityTopK = imageSimilarityTopK ?? DEFAULT_SIMILARITY_TOP_K; + this.topK = topK ?? { + [ModalityType.TEXT]: similarityTopK ?? DEFAULT_SIMILARITY_TOP_K, + [ModalityType.IMAGE]: DEFAULT_SIMILARITY_TOP_K, + }; + } + + /** + * @deprecated, pass topK in constructor instead + */ + set similarityTopK(similarityTopK: number) { + this.topK[ModalityType.TEXT] = similarityTopK; } @wrapEventCaller @@ -416,13 +418,21 @@ export class VectorIndexRetriever implements BaseRetriever { query, }, }); - let nodesWithScores = await this.textRetrieve( - query, - preFilters as MetadataFilters, - ); - nodesWithScores = nodesWithScores.concat( - await this.textToImageRetrieve(query, preFilters as MetadataFilters), - ); + const vectorStores = this.index.vectorStores; + let nodesWithScores: NodeWithScore[] = []; + + for (const type in vectorStores) { + // TODO: add retrieval by using an image as query + const vectorStore: VectorStore = vectorStores[type as ModalityType]!; + nodesWithScores = nodesWithScores.concat( + await this.textRetrieve( + query, + type as ModalityType, + vectorStore, + preFilters as MetadataFilters, + ), + ); + } getCallbackManager().dispatchEvent("retrieve-end", { payload: { query, @@ -439,34 +449,17 @@ export class VectorIndexRetriever implements BaseRetriever { protected async textRetrieve( query: string, + type: ModalityType, + vectorStore: VectorStore, preFilters?: MetadataFilters, ): Promise<NodeWithScore[]> { - const options = {}; const q = await this.buildVectorStoreQuery( - this.index.embedModel, + this.index.embedModel ?? vectorStore.embedModel, query, - this.similarityTopK, + this.topK[type], preFilters, ); - const result = await this.index.vectorStore.query(q, options); - return this.buildNodeListFromQueryResult(result); - } - - private async textToImageRetrieve( - query: string, - preFilters?: MetadataFilters, - ) { - if (!this.index.imageEmbedModel || !this.index.imageVectorStore) { - // no-op if image embedding and vector store are not set - return []; - } - const q = await this.buildVectorStoreQuery( - this.index.imageEmbedModel, - query, - this.imageSimilarityTopK, - preFilters, - ); - const result = await this.index.imageVectorStore.query(q, preFilters); + const result = await vectorStore.query(q); return this.buildNodeListFromQueryResult(result); } @@ -479,9 +472,9 @@ export class VectorIndexRetriever implements BaseRetriever { const queryEmbedding = await embedModel.getQueryEmbedding(query); return { - queryEmbedding: queryEmbedding, + queryEmbedding, mode: VectorStoreQueryMode.DEFAULT, - similarityTopK: similarityTopK, + similarityTopK, filters: preFilters ?? undefined, }; } diff --git a/packages/core/src/ingestion/IngestionPipeline.ts b/packages/core/src/ingestion/IngestionPipeline.ts index d93867ce24cf6fbf0b03767a51092aa613389cb6..c52bb7b7d83497adeb9fe6562765af7a622ba033 100644 --- a/packages/core/src/ingestion/IngestionPipeline.ts +++ b/packages/core/src/ingestion/IngestionPipeline.ts @@ -1,5 +1,11 @@ import type { PlatformApiClient } from "@llamaindex/cloud"; -import type { BaseNode, Document } from "../Node.js"; +import { + ModalityType, + splitNodesByType, + type BaseNode, + type Document, + type Metadata, +} from "../Node.js"; import { getPipelineCreate } from "../cloud/config.js"; import { DEFAULT_PIPELINE_NAME, @@ -9,7 +15,10 @@ import { import { getAppBaseUrl, getClient } from "../cloud/utils.js"; import type { BaseReader } from "../readers/type.js"; import type { BaseDocumentStore } from "../storage/docStore/types.js"; -import type { VectorStore } from "../storage/vectorStore/types.js"; +import type { + VectorStore, + VectorStoreByType, +} from "../storage/vectorStore/types.js"; import { IngestionCache, getTransformationHash } from "./IngestionCache.js"; import { DocStoreStrategy, @@ -63,6 +72,7 @@ export class IngestionPipeline { documents?: Document[]; reader?: BaseReader; vectorStore?: VectorStore; + vectorStores?: VectorStoreByType; docStore?: BaseDocumentStore; docStoreStrategy: DocStoreStrategy = DocStoreStrategy.UPSERTS; cache?: IngestionCache; @@ -80,10 +90,13 @@ export class IngestionPipeline { if (!this.docStore) { this.docStoreStrategy = DocStoreStrategy.NONE; } + this.vectorStores = this.vectorStores ?? { + [ModalityType.TEXT]: this.vectorStore, + }; this._docStoreStrategy = createDocStoreStrategy( this.docStoreStrategy, this.docStore, - this.vectorStore, + Object.values(this.vectorStores), ); if (!this.disableCache) { this.cache = new IngestionCache(); @@ -123,9 +136,9 @@ export class IngestionPipeline { transformOptions, args, ); - if (this.vectorStore) { + if (this.vectorStores) { const nodesToAdd = nodes.filter((node) => node.embedding); - await this.vectorStore.add(nodesToAdd); + await addNodesToVectorStores(nodesToAdd, this.vectorStores); } return nodes; } @@ -176,3 +189,30 @@ export class IngestionPipeline { return pipeline.id; } } + +export async function addNodesToVectorStores( + nodes: BaseNode<Metadata>[], + vectorStores: VectorStoreByType, + nodesAdded?: ( + newIds: string[], + nodes: BaseNode<Metadata>[], + vectorStore: VectorStore, + ) => Promise<void>, +) { + const nodeMap = splitNodesByType(nodes); + for (const type in nodeMap) { + const nodes = nodeMap[type as ModalityType]; + if (nodes) { + const vectorStore = vectorStores[type as ModalityType]; + if (!vectorStore) { + throw new Error( + `Cannot insert nodes of type ${type} without assigned vector store`, + ); + } + const newIds = await vectorStore.add(nodes); + if (nodesAdded) { + await nodesAdded(newIds, nodes, vectorStore); + } + } + } +} diff --git a/packages/core/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts b/packages/core/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts index cfeae7f785b7f382b4cbe7c40ed468f6e6ab4ec4..941face51cc609f53830ba0cc8d523b1e2b6891a 100644 --- a/packages/core/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts +++ b/packages/core/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts @@ -10,11 +10,11 @@ import { classify } from "./classify.js"; */ export class UpsertsAndDeleteStrategy implements TransformComponent { protected docStore: BaseDocumentStore; - protected vectorStore?: VectorStore; + protected vectorStores?: VectorStore[]; - constructor(docStore: BaseDocumentStore, vectorStore?: VectorStore) { + constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { this.docStore = docStore; - this.vectorStore = vectorStore; + this.vectorStores = vectorStores; } async transform(nodes: BaseNode[]): Promise<BaseNode[]> { @@ -26,16 +26,20 @@ export class UpsertsAndDeleteStrategy implements TransformComponent { // remove unused docs for (const refDocId of unusedDocs) { await this.docStore.deleteRefDoc(refDocId, false); - if (this.vectorStore) { - await this.vectorStore.delete(refDocId); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(refDocId); + } } } // remove missing docs for (const docId of missingDocs) { await this.docStore.deleteDocument(docId, true); - if (this.vectorStore) { - await this.vectorStore.delete(docId); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(docId); + } } } diff --git a/packages/core/src/ingestion/strategies/UpsertsStrategy.ts b/packages/core/src/ingestion/strategies/UpsertsStrategy.ts index b562b1e426e7e6d8c1c7e639213e924298adbb11..b05df033a55c4680e1b963df3860e50589361cb3 100644 --- a/packages/core/src/ingestion/strategies/UpsertsStrategy.ts +++ b/packages/core/src/ingestion/strategies/UpsertsStrategy.ts @@ -9,11 +9,11 @@ import { classify } from "./classify.js"; */ export class UpsertsStrategy implements TransformComponent { protected docStore: BaseDocumentStore; - protected vectorStore?: VectorStore; + protected vectorStores?: VectorStore[]; - constructor(docStore: BaseDocumentStore, vectorStore?: VectorStore) { + constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { this.docStore = docStore; - this.vectorStore = vectorStore; + this.vectorStores = vectorStores; } async transform(nodes: BaseNode[]): Promise<BaseNode[]> { @@ -21,8 +21,10 @@ export class UpsertsStrategy implements TransformComponent { // remove unused docs for (const refDocId of unusedDocs) { await this.docStore.deleteRefDoc(refDocId, false); - if (this.vectorStore) { - await this.vectorStore.delete(refDocId); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(refDocId); + } } } // add non-duplicate docs diff --git a/packages/core/src/ingestion/strategies/index.ts b/packages/core/src/ingestion/strategies/index.ts index ed324e0dc4d770b277c65b63a8ba36a70e6465b1..13d916c920dd4b2b468cb821fa6aeba3c7194be0 100644 --- a/packages/core/src/ingestion/strategies/index.ts +++ b/packages/core/src/ingestion/strategies/index.ts @@ -28,7 +28,7 @@ class NoOpStrategy implements TransformComponent { export function createDocStoreStrategy( docStoreStrategy: DocStoreStrategy, docStore?: BaseDocumentStore, - vectorStore?: VectorStore, + vectorStores: VectorStore[] = [], ): TransformComponent { if (docStoreStrategy === DocStoreStrategy.NONE) { return new NoOpStrategy(); @@ -36,11 +36,11 @@ export function createDocStoreStrategy( if (!docStore) { throw new Error("docStore is required to create a doc store strategy."); } - if (vectorStore) { + if (vectorStores.length > 0) { if (docStoreStrategy === DocStoreStrategy.UPSERTS) { - return new UpsertsStrategy(docStore, vectorStore); + return new UpsertsStrategy(docStore, vectorStores); } else if (docStoreStrategy === DocStoreStrategy.UPSERTS_AND_DELETE) { - return new UpsertsAndDeleteStrategy(docStore, vectorStore); + return new UpsertsAndDeleteStrategy(docStore, vectorStores); } else if (docStoreStrategy === DocStoreStrategy.DUPLICATES_ONLY) { return new DuplicatesStrategy(docStore); } else { diff --git a/packages/core/src/internal/settings/EmbedModel.ts b/packages/core/src/internal/settings/EmbedModel.ts new file mode 100644 index 0000000000000000000000000000000000000000..064ce1b7df32432993bb93c4c968245c8934c19e --- /dev/null +++ b/packages/core/src/internal/settings/EmbedModel.ts @@ -0,0 +1,24 @@ +import { AsyncLocalStorage } from "@llamaindex/env"; +import { OpenAIEmbedding } from "../../embeddings/OpenAIEmbedding.js"; +import type { BaseEmbedding } from "../../embeddings/index.js"; + +const embeddedModelAsyncLocalStorage = new AsyncLocalStorage<BaseEmbedding>(); +let globalEmbeddedModel: BaseEmbedding | null = null; + +export function getEmbeddedModel(): BaseEmbedding { + if (globalEmbeddedModel === null) { + globalEmbeddedModel = new OpenAIEmbedding(); + } + return embeddedModelAsyncLocalStorage.getStore() ?? globalEmbeddedModel; +} + +export function setEmbeddedModel(embeddedModel: BaseEmbedding) { + globalEmbeddedModel = embeddedModel; +} + +export function withEmbeddedModel<Result>( + embeddedModel: BaseEmbedding, + fn: () => Result, +): Result { + return embeddedModelAsyncLocalStorage.run(embeddedModel, fn); +} diff --git a/packages/core/src/storage/StorageContext.ts b/packages/core/src/storage/StorageContext.ts index 8872a67c597b831aed201dbd0ce44359d8d2dc30..a34cf04c64c137c546299e0852377d92af9ffcde 100644 --- a/packages/core/src/storage/StorageContext.ts +++ b/packages/core/src/storage/StorageContext.ts @@ -1,4 +1,6 @@ import { path } from "@llamaindex/env"; +import { ModalityType, ObjectType } from "../Node.js"; +import { ClipEmbedding } from "../embeddings/ClipEmbedding.js"; import { DEFAULT_IMAGE_VECTOR_NAMESPACE, DEFAULT_NAMESPACE, @@ -8,20 +10,19 @@ import type { BaseDocumentStore } from "./docStore/types.js"; import { SimpleIndexStore } from "./indexStore/SimpleIndexStore.js"; import type { BaseIndexStore } from "./indexStore/types.js"; import { SimpleVectorStore } from "./vectorStore/SimpleVectorStore.js"; -import type { VectorStore } from "./vectorStore/types.js"; +import type { VectorStore, VectorStoreByType } from "./vectorStore/types.js"; export interface StorageContext { docStore: BaseDocumentStore; indexStore: BaseIndexStore; - vectorStore: VectorStore; - imageVectorStore?: VectorStore; + vectorStores: VectorStoreByType; } -export type BuilderParams = { +type BuilderParams = { docStore: BaseDocumentStore; indexStore: BaseIndexStore; vectorStore: VectorStore; - imageVectorStore: VectorStore; + vectorStores: VectorStoreByType; storeImages: boolean; persistDir: string; }; @@ -30,34 +31,43 @@ export async function storageContextFromDefaults({ docStore, indexStore, vectorStore, - imageVectorStore, + vectorStores, storeImages, persistDir, }: Partial<BuilderParams>): Promise<StorageContext> { + vectorStores = vectorStores ?? {}; if (!persistDir) { - docStore = docStore || new SimpleDocumentStore(); - indexStore = indexStore || new SimpleIndexStore(); - vectorStore = vectorStore || new SimpleVectorStore(); - imageVectorStore = storeImages ? new SimpleVectorStore() : imageVectorStore; + docStore = docStore ?? new SimpleDocumentStore(); + indexStore = indexStore ?? new SimpleIndexStore(); + if (!(ModalityType.TEXT in vectorStores)) { + vectorStores[ModalityType.TEXT] = vectorStore ?? new SimpleVectorStore(); + } + if (storeImages && !(ModalityType.IMAGE in vectorStores)) { + vectorStores[ModalityType.IMAGE] = new SimpleVectorStore({ + embedModel: new ClipEmbedding(), + }); + } } else { docStore = docStore || (await SimpleDocumentStore.fromPersistDir(persistDir, DEFAULT_NAMESPACE)); indexStore = indexStore || (await SimpleIndexStore.fromPersistDir(persistDir)); - vectorStore = - vectorStore || (await SimpleVectorStore.fromPersistDir(persistDir)); - imageVectorStore = storeImages - ? await SimpleVectorStore.fromPersistDir( - path.join(persistDir, DEFAULT_IMAGE_VECTOR_NAMESPACE), - ) - : imageVectorStore; + if (!(ObjectType.TEXT in vectorStores)) { + vectorStores[ModalityType.TEXT] = + vectorStore ?? (await SimpleVectorStore.fromPersistDir(persistDir)); + } + if (storeImages && !(ObjectType.IMAGE in vectorStores)) { + vectorStores[ModalityType.IMAGE] = await SimpleVectorStore.fromPersistDir( + path.join(persistDir, DEFAULT_IMAGE_VECTOR_NAMESPACE), + new ClipEmbedding(), + ); + } } return { docStore, indexStore, - vectorStore, - imageVectorStore, + vectorStores, }; } diff --git a/packages/core/src/storage/vectorStore/AstraDBVectorStore.ts b/packages/core/src/storage/vectorStore/AstraDBVectorStore.ts index e0d083513bf62559b90463ad376d50ddfb46f873..ed732218e5a70b48e37d86db16bb9d71a9f1ed69 100644 --- a/packages/core/src/storage/vectorStore/AstraDBVectorStore.ts +++ b/packages/core/src/storage/vectorStore/AstraDBVectorStore.ts @@ -2,14 +2,19 @@ import { Collection, DataAPIClient, Db } from "@datastax/astra-db-ts"; import { getEnv } from "@llamaindex/env"; import type { BaseNode } from "../../Node.js"; import { MetadataMode } from "../../Node.js"; -import type { - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + type IEmbedModel, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; import { metadataDictToNode, nodeToMetadata } from "./utils.js"; -export class AstraDBVectorStore implements VectorStore { +export class AstraDBVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = true; flatMetadata: boolean = true; @@ -27,8 +32,9 @@ export class AstraDBVectorStore implements VectorStore { endpoint: string; namespace?: string; }; - }, + } & Partial<IEmbedModel>, ) { + super(init?.embedModel); const token = init?.params?.token ?? getEnv("ASTRA_DB_APPLICATION_TOKEN"); const endpoint = init?.params?.endpoint ?? getEnv("ASTRA_DB_API_ENDPOINT"); diff --git a/packages/core/src/storage/vectorStore/ChromaVectorStore.ts b/packages/core/src/storage/vectorStore/ChromaVectorStore.ts index 095d5ad53d92e8433cd5dc22cbe242350bc82597..a8b27b0f09b21914153dc66140421de2642e43a1 100644 --- a/packages/core/src/storage/vectorStore/ChromaVectorStore.ts +++ b/packages/core/src/storage/vectorStore/ChromaVectorStore.ts @@ -9,12 +9,14 @@ import type { import { ChromaClient, IncludeEnum } from "chromadb"; import type { BaseNode } from "../../Node.js"; import { MetadataMode } from "../../Node.js"; -import type { - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + VectorStoreQueryMode, + type IEmbedModel, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; -import { VectorStoreQueryMode } from "./types.js"; import { metadataDictToNode, nodeToMetadata } from "./utils.js"; type ChromaDeleteOptions = { @@ -28,7 +30,10 @@ type ChromaQueryOptions = { const DEFAULT_TEXT_KEY = "text"; -export class ChromaVectorStore implements VectorStore { +export class ChromaVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = true; flatMetadata: boolean = true; textKey: string; @@ -36,11 +41,14 @@ export class ChromaVectorStore implements VectorStore { private collection: Collection | null = null; private collectionName: string; - constructor(init: { - collectionName: string; - textKey?: string; - chromaClientParams?: ChromaClientParams; - }) { + constructor( + init: { + collectionName: string; + textKey?: string; + chromaClientParams?: ChromaClientParams; + } & Partial<IEmbedModel>, + ) { + super(init.embedModel); this.collectionName = init.collectionName; this.chromaClient = new ChromaClient(init.chromaClientParams); this.textKey = init.textKey ?? DEFAULT_TEXT_KEY; diff --git a/packages/core/src/storage/vectorStore/MilvusVectorStore.ts b/packages/core/src/storage/vectorStore/MilvusVectorStore.ts index 0a5f3f1077bad53fa0bf6fbb5e74295fa6b0cc79..94848802eb4cbc3cd5fba83a9981ffcf36605f0e 100644 --- a/packages/core/src/storage/vectorStore/MilvusVectorStore.ts +++ b/packages/core/src/storage/vectorStore/MilvusVectorStore.ts @@ -8,14 +8,19 @@ import { type RowData, } from "@zilliz/milvus2-sdk-node"; import { BaseNode, MetadataMode, type Metadata } from "../../Node.js"; -import type { - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + type IEmbedModel, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; import { metadataDictToNode, nodeToMetadata } from "./utils.js"; -export class MilvusVectorStore implements VectorStore { +export class MilvusVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ public storesText: boolean = true; public isEmbeddingQuery?: boolean; private flatMetadata: boolean = true; @@ -30,21 +35,23 @@ export class MilvusVectorStore implements VectorStore { private embeddingKey: string; constructor( - init?: Partial<{ milvusClient: MilvusClient }> & { - params?: { - configOrAddress: ClientConfig | string; - ssl?: boolean; - username?: string; - password?: string; - channelOptions?: ChannelOptions; - }; - collection?: string; - idKey?: string; - contentKey?: string; - metadataKey?: string; - embeddingKey?: string; - }, + init?: Partial<{ milvusClient: MilvusClient }> & + Partial<IEmbedModel> & { + params?: { + configOrAddress: ClientConfig | string; + ssl?: boolean; + username?: string; + password?: string; + channelOptions?: ChannelOptions; + }; + collection?: string; + idKey?: string; + contentKey?: string; + metadataKey?: string; + embeddingKey?: string; + }, ) { + super(init?.embedModel); if (init?.milvusClient) { this.milvusClient = init.milvusClient; } else { diff --git a/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts b/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts index b37fe67695cc17d2e49374dacd83881929fb42f4..7d0fbafca19d1322ef69a6f8f32431516657d884 100644 --- a/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts +++ b/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts @@ -3,11 +3,12 @@ import type { BulkWriteOptions, Collection } from "mongodb"; import { MongoClient } from "mongodb"; import type { BaseNode } from "../../Node.js"; import { MetadataMode } from "../../Node.js"; -import type { - MetadataFilters, - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + type MetadataFilters, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; import { metadataDictToNode, nodeToMetadata } from "./utils.js"; @@ -23,7 +24,10 @@ function toMongoDBFilter( } // MongoDB Atlas Vector Store class implementing VectorStore -export class MongoDBAtlasVectorSearch implements VectorStore { +export class MongoDBAtlasVectorSearch + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = true; flatMetadata: boolean = true; @@ -42,6 +46,7 @@ export class MongoDBAtlasVectorSearch implements VectorStore { collectionName: string; }, ) { + super(); if (init.mongodbClient) { this.mongodbClient = init.mongodbClient; } else { diff --git a/packages/core/src/storage/vectorStore/PGVectorStore.ts b/packages/core/src/storage/vectorStore/PGVectorStore.ts index b46a2a7b161707fa3cf7a947bffdf0d5ad8d22e2..e06810755de6867d2570fe6d629f3cf643feb0ee 100644 --- a/packages/core/src/storage/vectorStore/PGVectorStore.ts +++ b/packages/core/src/storage/vectorStore/PGVectorStore.ts @@ -1,9 +1,11 @@ import type pg from "pg"; -import type { - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + type IEmbedModel, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; import type { BaseNode, Metadata } from "../../Node.js"; @@ -16,7 +18,10 @@ export const PGVECTOR_TABLE = "llamaindex_embedding"; * Provides support for writing and querying vector data in Postgres. * Note: Can't be used with data created using the Python version of the vector store (https://docs.llamaindex.ai/en/stable/examples/vector_stores/postgres.html) */ -export class PGVectorStore implements VectorStore { +export class PGVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = true; private collection: string = ""; @@ -44,12 +49,15 @@ export class PGVectorStore implements VectorStore { * @param {string} config.connectionString - The connection string (optional). * @param {number} config.dimensions - The dimensions of the embedding model. */ - constructor(config?: { - schemaName?: string; - tableName?: string; - connectionString?: string; - dimensions?: number; - }) { + constructor( + config?: { + schemaName?: string; + tableName?: string; + connectionString?: string; + dimensions?: number; + } & Partial<IEmbedModel>, + ) { + super(config?.embedModel); this.schemaName = config?.schemaName ?? PGVECTOR_SCHEMA; this.tableName = config?.tableName ?? PGVECTOR_TABLE; this.connectionString = config?.connectionString; diff --git a/packages/core/src/storage/vectorStore/PineconeVectorStore.ts b/packages/core/src/storage/vectorStore/PineconeVectorStore.ts index cf5f789ac2373b5f9e6344cae36cc2ad7b09eded..4311e56687e2fc6b8f9033b27c775ea27417ad32 100644 --- a/packages/core/src/storage/vectorStore/PineconeVectorStore.ts +++ b/packages/core/src/storage/vectorStore/PineconeVectorStore.ts @@ -1,9 +1,11 @@ -import type { - ExactMatchFilter, - MetadataFilters, - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + type ExactMatchFilter, + type IEmbedModel, + type MetadataFilters, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; import { getEnv } from "@llamaindex/env"; @@ -21,12 +23,15 @@ type PineconeParams = { chunkSize?: number; namespace?: string; textKey?: string; -}; +} & IEmbedModel; /** - * Provides support for writing and querying vector data in Postgres. + * Provides support for writing and querying vector data in Pinecone. */ -export class PineconeVectorStore implements VectorStore { +export class PineconeVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = true; /* @@ -44,6 +49,7 @@ export class PineconeVectorStore implements VectorStore { textKey: string; constructor(params?: PineconeParams) { + super(params?.embedModel); this.indexName = params?.indexName ?? getEnv("PINECONE_INDEX_NAME") ?? "llama"; this.namespace = params?.namespace ?? getEnv("PINECONE_NAMESPACE") ?? ""; diff --git a/packages/core/src/storage/vectorStore/QdrantVectorStore.ts b/packages/core/src/storage/vectorStore/QdrantVectorStore.ts index 61eecf7a905e5b0e60ef1db2c435e249178e6cd6..789ea6bd23c867b0162f1466d2a348671c99b310 100644 --- a/packages/core/src/storage/vectorStore/QdrantVectorStore.ts +++ b/packages/core/src/storage/vectorStore/QdrantVectorStore.ts @@ -1,8 +1,10 @@ import type { BaseNode } from "../../Node.js"; -import type { - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + type IEmbedModel, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; import { QdrantClient } from "@qdrant/js-client-rest"; @@ -20,7 +22,7 @@ type QdrantParams = { url?: string; apiKey?: string; batchSize?: number; -}; +} & Partial<IEmbedModel>; type QuerySearchResult = { id: string; @@ -33,7 +35,10 @@ type QuerySearchResult = { /** * Qdrant vector store. */ -export class QdrantVectorStore implements VectorStore { +export class QdrantVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = true; batchSize: number; @@ -49,6 +54,7 @@ export class QdrantVectorStore implements VectorStore { * @param url Qdrant URL * @param apiKey Qdrant API key * @param batchSize Number of vectors to upload in a single batch + * @param embedModel Embedding model */ constructor({ collectionName, @@ -56,7 +62,9 @@ export class QdrantVectorStore implements VectorStore { url, apiKey, batchSize, + embedModel, }: QdrantParams) { + super(embedModel); if (!client && !url) { if (!url) { throw new Error("QdrantVectorStore requires url and collectionName"); diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts index 0652c3214cc9443bc3de0d750dd9b6f89d0d7c90..35e801e46be10e8175dccc589be791017f38e74c 100644 --- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts @@ -1,5 +1,6 @@ import { fs, path } from "@llamaindex/env"; import type { BaseNode } from "../../Node.js"; +import { BaseEmbedding } from "../../embeddings/index.js"; import { getTopKEmbeddings, getTopKEmbeddingsLearner, @@ -7,12 +8,14 @@ import { } from "../../embeddings/utils.js"; import { exists } from "../FileSystem.js"; import { DEFAULT_PERSIST_DIR } from "../constants.js"; -import type { - VectorStore, - VectorStoreQuery, - VectorStoreQueryResult, +import { + VectorStoreBase, + VectorStoreQueryMode, + type IEmbedModel, + type VectorStoreNoEmbedModel, + type VectorStoreQuery, + type VectorStoreQueryResult, } from "./types.js"; -import { VectorStoreQueryMode } from "./types.js"; const LEARNER_MODES = new Set<VectorStoreQueryMode>([ VectorStoreQueryMode.SVM, @@ -27,20 +30,25 @@ class SimpleVectorStoreData { textIdToRefDocId: Record<string, string> = {}; } -export class SimpleVectorStore implements VectorStore { +export class SimpleVectorStore + extends VectorStoreBase + implements VectorStoreNoEmbedModel +{ storesText: boolean = false; - private data: SimpleVectorStoreData = new SimpleVectorStoreData(); + private data: SimpleVectorStoreData; private persistPath: string | undefined; - constructor(data?: SimpleVectorStoreData) { - this.data = data || new SimpleVectorStoreData(); + constructor(init?: { data?: SimpleVectorStoreData } & Partial<IEmbedModel>) { + super(init?.embedModel); + this.data = init?.data || new SimpleVectorStoreData(); } static async fromPersistDir( persistDir: string = DEFAULT_PERSIST_DIR, + embedModel?: BaseEmbedding, ): Promise<SimpleVectorStore> { const persistPath = path.join(persistDir, "vector_store.json"); - return await SimpleVectorStore.fromPersistPath(persistPath); + return await SimpleVectorStore.fromPersistPath(persistPath, embedModel); } get client(): any { @@ -154,6 +162,7 @@ export class SimpleVectorStore implements VectorStore { static async fromPersistPath( persistPath: string, + embedModel?: BaseEmbedding, ): Promise<SimpleVectorStore> { const dirPath = path.dirname(persistPath); if (!(await exists(dirPath))) { @@ -173,16 +182,19 @@ export class SimpleVectorStore implements VectorStore { const data = new SimpleVectorStoreData(); data.embeddingDict = dataDict.embeddingDict ?? {}; data.textIdToRefDocId = dataDict.textIdToRefDocId ?? {}; - const store = new SimpleVectorStore(data); + const store = new SimpleVectorStore({ data, embedModel }); store.persistPath = persistPath; return store; } - static fromDict(saveDict: SimpleVectorStoreData): SimpleVectorStore { + static fromDict( + saveDict: SimpleVectorStoreData, + embedModel?: BaseEmbedding, + ): SimpleVectorStore { const data = new SimpleVectorStoreData(); data.embeddingDict = saveDict.embeddingDict; data.textIdToRefDocId = saveDict.textIdToRefDocId; - return new SimpleVectorStore(data); + return new SimpleVectorStore({ data, embedModel }); } toDict(): SimpleVectorStoreData { diff --git a/packages/core/src/storage/vectorStore/types.ts b/packages/core/src/storage/vectorStore/types.ts index c14ec60778b97c5441e4ce3b213d44fe2719423e..34a1cf343b3ebe6f608d79eb3036ec08e428031e 100644 --- a/packages/core/src/storage/vectorStore/types.ts +++ b/packages/core/src/storage/vectorStore/types.ts @@ -1,4 +1,6 @@ -import type { BaseNode } from "../../Node.js"; +import type { BaseNode, ModalityType } from "../../Node.js"; +import type { BaseEmbedding } from "../../embeddings/types.js"; +import { getEmbeddedModel } from "../../internal/settings/EmbedModel.js"; export interface VectorStoreQueryResult { nodes?: BaseNode[]; @@ -56,7 +58,7 @@ export interface VectorStoreQuery { mmrThreshold?: number; } -export interface VectorStore { +export interface VectorStoreNoEmbedModel { storesText: boolean; isEmbeddingQuery?: boolean; client(): any; @@ -67,3 +69,23 @@ export interface VectorStore { options?: any, ): Promise<VectorStoreQueryResult>; } + +export interface IEmbedModel { + embedModel: BaseEmbedding; +} + +export interface VectorStore extends VectorStoreNoEmbedModel, IEmbedModel {} + +// Supported types of vector stores (for each modality) + +export type VectorStoreByType = { + [P in ModalityType]?: VectorStore; +}; + +export abstract class VectorStoreBase implements IEmbedModel { + embedModel: BaseEmbedding; + + protected constructor(embedModel?: BaseEmbedding) { + this.embedModel = embedModel ?? getEmbeddedModel(); + } +} diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts index 630a182425820a0df99f7d77293dff36f02d037d..d2a97c5f1aa65c68a7913f6539943b8f4d634c87 100644 --- a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -1,5 +1,5 @@ import type { ImageNode } from "../Node.js"; -import { MetadataMode, splitNodesByType } from "../Node.js"; +import { MetadataMode, ModalityType, splitNodesByType } from "../Node.js"; import { Response } from "../Response.js"; import type { ServiceContext } from "../ServiceContext.js"; import { llmFromSettingsOrContext } from "../Settings.js"; @@ -63,7 +63,10 @@ export class MultiModalResponseSynthesizer throw new Error("streaming not implemented"); } const nodes = nodesWithScore.map(({ node }) => node); - const { imageNodes, textNodes } = splitNodesByType(nodes); + const nodeMap = splitNodesByType(nodes); + const imageNodes: ImageNode[] = + (nodeMap[ModalityType.IMAGE] as ImageNode[]) ?? []; + const textNodes = nodeMap[ModalityType.TEXT] ?? []; const textChunks = textNodes.map((node) => node.getContent(this.metadataMode), ); diff --git a/packages/core/tests/indices/VectorStoreIndex.test.ts b/packages/core/tests/indices/VectorStoreIndex.test.ts index e2578adde509b67a917ac645d63678059016cafa..82acb55c52e9ed59118025a6b3607e27efdc2287 100644 --- a/packages/core/tests/indices/VectorStoreIndex.test.ts +++ b/packages/core/tests/indices/VectorStoreIndex.test.ts @@ -1,9 +1,5 @@ import type { ServiceContext, StorageContext } from "llamaindex"; -import { - Document, - VectorStoreIndex, - storageContextFromDefaults, -} from "llamaindex"; +import { Document, VectorStoreIndex } from "llamaindex"; import { DocStoreStrategy } from "llamaindex/ingestion/strategies/index"; import { mkdtemp, rm } from "node:fs/promises"; import { tmpdir } from "node:os"; @@ -13,6 +9,7 @@ import { afterAll, beforeAll, describe, expect, test } from "vitest"; const testDir = await mkdtemp(join(tmpdir(), "test-")); import { mockServiceContext } from "../utility/mockServiceContext.js"; +import { mockStorageContext } from "../utility/mockStorageContext.js"; describe("VectorStoreIndex", () => { let serviceContext: ServiceContext; @@ -24,9 +21,7 @@ describe("VectorStoreIndex", () => { beforeAll(async () => { serviceContext = mockServiceContext(); - storageContext = await storageContextFromDefaults({ - persistDir: testDir, - }); + storageContext = await mockStorageContext(testDir); testStrategy = async ( strategy: DocStoreStrategy, runs: number = 2, diff --git a/packages/core/tests/utility/mockStorageContext.ts b/packages/core/tests/utility/mockStorageContext.ts new file mode 100644 index 0000000000000000000000000000000000000000..4a8c0f8802ff61b31d2fb2c99856a6f74ecb7993 --- /dev/null +++ b/packages/core/tests/utility/mockStorageContext.ts @@ -0,0 +1,14 @@ +import { OpenAIEmbedding, storageContextFromDefaults } from "llamaindex"; + +import { mockEmbeddingModel } from "./mockOpenAI.js"; + +export async function mockStorageContext(testDir: string) { + const storageContext = await storageContextFromDefaults({ + persistDir: testDir, + }); + for (const store of Object.values(storageContext.vectorStores)) { + store.embedModel = new OpenAIEmbedding(); + mockEmbeddingModel(store.embedModel as OpenAIEmbedding); + } + return storageContext; +}