Skip to content
Snippets Groups Projects
Unverified Commit dca02f72 authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

refactor: VectorStoreIndex: use TransformerComponent to calc embeddings (#721)

parent b757fa9a
No related branches found
No related tags found
No related merge requests found
......@@ -154,7 +154,7 @@ If you need any of those classes, you have to import them instead directly. Here
import { PineconeVectorStore } from "@llamaindex/edge/storage/vectorStore/PineconeVectorStore";
```
As the `PDFReader` is not with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs:
As the `PDFReader` is not working with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs:
```typescript
import { SimpleDirectoryReader } from "@llamaindex/edge/readers/SimpleDirectoryReader";
......
......@@ -4,6 +4,7 @@ import {
VectorStoreIndex,
storageContextFromDefaults,
} from "llamaindex";
import { DocStoreStrategy } from "llamaindex/ingestion/strategies/index";
import * as path from "path";
......@@ -31,6 +32,7 @@ async function generateDatasource() {
});
await VectorStoreIndex.fromDocuments(documents, {
storageContext,
docStoreStrategy: DocStoreStrategy.NONE,
});
});
console.log(`Storage successfully generated in ${ms / 1000}s.`);
......
......@@ -28,6 +28,7 @@ export class OpenAIEmbedding implements BaseEmbedding {
}
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
nodes.forEach((node) => (node.embedding = [0]));
return nodes;
}
}
import type { ImageType } from "../Node.js";
import { BaseEmbedding } from "./types.js";
import {
MetadataMode,
splitNodesByType,
type BaseNode,
type ImageType,
} from "../Node.js";
import { BaseEmbedding, batchEmbeddings } from "./types.js";
/*
* Base class for Multi Modal embeddings.
......@@ -8,9 +13,39 @@ import { BaseEmbedding } from "./types.js";
export abstract class MultiModalEmbedding extends BaseEmbedding {
abstract getImageEmbedding(images: ImageType): Promise<number[]>;
/**
* Optionally override this method to retrieve multiple image embeddings in a single request
* @param texts
*/
async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
return Promise.all(
images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
);
}
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
const { imageNodes, textNodes } = splitNodesByType(nodes);
const embeddings = await batchEmbeddings(
textNodes.map((node) => node.getContent(MetadataMode.EMBED)),
this.getTextEmbeddings.bind(this),
this.embedBatchSize,
_options,
);
for (let i = 0; i < textNodes.length; i++) {
textNodes[i].embedding = embeddings[i];
}
const imageEmbeddings = await batchEmbeddings(
imageNodes.map((n) => n.image),
this.getImageEmbeddings.bind(this),
this.embedBatchSize,
_options,
);
for (let i = 0; i < imageNodes.length; i++) {
imageNodes[i].embedding = imageEmbeddings[i];
}
return nodes;
}
}
......@@ -5,6 +5,8 @@ import { SimilarityType, similarity } from "./utils.js";
const DEFAULT_EMBED_BATCH_SIZE = 10;
type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>;
export abstract class BaseEmbedding implements TransformComponent {
embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
......@@ -45,35 +47,18 @@ export abstract class BaseEmbedding implements TransformComponent {
logProgress?: boolean;
},
): Promise<Array<number[]>> {
const resultEmbeddings: Array<number[]> = [];
const chunkSize = this.embedBatchSize;
const queue: string[] = texts;
const curBatch: string[] = [];
for (let i = 0; i < queue.length; i++) {
curBatch.push(queue[i]);
if (i == queue.length - 1 || curBatch.length == chunkSize) {
const embeddings = await this.getTextEmbeddings(curBatch);
resultEmbeddings.push(...embeddings);
if (options?.logProgress) {
console.log(`getting embedding progress: ${i} / ${queue.length}`);
}
curBatch.length = 0;
}
}
return resultEmbeddings;
return await batchEmbeddings(
texts,
this.getTextEmbeddings.bind(this),
this.embedBatchSize,
options,
);
}
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
const embeddings = await this.getTextEmbeddingsBatch(texts);
const embeddings = await this.getTextEmbeddingsBatch(texts, _options);
for (let i = 0; i < nodes.length; i++) {
nodes[i].embedding = embeddings[i];
......@@ -82,3 +67,35 @@ export abstract class BaseEmbedding implements TransformComponent {
return nodes;
}
}
export async function batchEmbeddings<T>(
values: T[],
embedFunc: EmbedFunc<T>,
chunkSize: number,
options?: {
logProgress?: boolean;
},
): Promise<Array<number[]>> {
const resultEmbeddings: Array<number[]> = [];
const queue: T[] = values;
const curBatch: T[] = [];
for (let i = 0; i < queue.length; i++) {
curBatch.push(queue[i]);
if (i == queue.length - 1 || curBatch.length == chunkSize) {
const embeddings = await embedFunc(curBatch);
resultEmbeddings.push(...embeddings);
if (options?.logProgress) {
console.log(`getting embedding progress: ${i} / ${queue.length}`);
}
curBatch.length = 0;
}
}
return resultEmbeddings;
}
......@@ -4,12 +4,7 @@ import type {
Metadata,
NodeWithScore,
} from "../../Node.js";
import {
ImageNode,
MetadataMode,
ObjectType,
splitNodesByType,
} from "../../Node.js";
import { ImageNode, ObjectType, splitNodesByType } from "../../Node.js";
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
import type { ServiceContext } from "../../ServiceContext.js";
import {
......@@ -179,14 +174,21 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
nodes: BaseNode[],
options?: { logProgress?: boolean },
): Promise<BaseNode[]> {
const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, {
const { imageNodes, textNodes } = splitNodesByType(nodes);
if (imageNodes.length > 0) {
if (!this.imageEmbedModel) {
throw new Error(
"Cannot calculate image nodes embedding without 'imageEmbedModel' set",
);
}
await this.imageEmbedModel.transform(imageNodes, {
logProgress: options?.logProgress,
});
}
await this.embedModel.transform(textNodes, {
logProgress: options?.logProgress,
});
return nodes.map((node, i) => {
node.embedding = embeddings[i];
return node;
});
return nodes;
}
/**
......@@ -324,25 +326,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
if (!nodes || nodes.length === 0) {
return;
}
nodes = await this.getNodeEmbeddingResults(nodes, options);
const { imageNodes, textNodes } = splitNodesByType(nodes);
if (imageNodes.length > 0) {
if (!this.imageVectorStore) {
throw new Error("Cannot insert image nodes without image vector store");
}
const imageNodesWithEmbedding = await this.getImageNodeEmbeddingResults(
imageNodes,
options,
);
await this.insertNodesToStore(
this.imageVectorStore,
imageNodesWithEmbedding,
);
await this.insertNodesToStore(this.imageVectorStore, imageNodes);
}
const embeddingResults = await this.getNodeEmbeddingResults(
textNodes,
options,
);
await this.insertNodesToStore(this.vectorStore, embeddingResults);
await this.insertNodesToStore(this.vectorStore, textNodes);
await this.indexStore.addIndexStruct(this.indexStruct);
}
......@@ -378,35 +370,6 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
await this.indexStore.addIndexStruct(this.indexStruct);
}
}
/**
* Calculates the embeddings for the given image nodes.
*
* @param nodes - An array of ImageNode objects representing the nodes for which embeddings are to be calculated.
* @param {Object} [options] - An optional object containing additional parameters.
* @param {boolean} [options.logProgress] - A boolean indicating whether to log progress to the console (useful for debugging).
*/
async getImageNodeEmbeddingResults(
nodes: ImageNode[],
options?: { logProgress?: boolean },
): Promise<ImageNode[]> {
if (!this.imageEmbedModel) {
return [];
}
const nodesWithEmbeddings: ImageNode[] = [];
for (let i = 0; i < nodes.length; ++i) {
const node = nodes[i];
if (options?.logProgress) {
console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
}
node.embedding = await this.imageEmbedModel.getImageEmbedding(node.image);
nodesWithEmbeddings.push(node);
}
return nodesWithEmbeddings;
}
}
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment