Skip to content
Snippets Groups Projects
Unverified Commit dca02f72 authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

refactor: VectorStoreIndex: use TransformerComponent to calc embeddings (#721)

parent b757fa9a
Branches
Tags
No related merge requests found
...@@ -154,7 +154,7 @@ If you need any of those classes, you have to import them instead directly. Here ...@@ -154,7 +154,7 @@ If you need any of those classes, you have to import them instead directly. Here
import { PineconeVectorStore } from "@llamaindex/edge/storage/vectorStore/PineconeVectorStore"; import { PineconeVectorStore } from "@llamaindex/edge/storage/vectorStore/PineconeVectorStore";
``` ```
As the `PDFReader` is not with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs: As the `PDFReader` is not working with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs:
```typescript ```typescript
import { SimpleDirectoryReader } from "@llamaindex/edge/readers/SimpleDirectoryReader"; import { SimpleDirectoryReader } from "@llamaindex/edge/readers/SimpleDirectoryReader";
......
...@@ -4,6 +4,7 @@ import { ...@@ -4,6 +4,7 @@ import {
VectorStoreIndex, VectorStoreIndex,
storageContextFromDefaults, storageContextFromDefaults,
} from "llamaindex"; } from "llamaindex";
import { DocStoreStrategy } from "llamaindex/ingestion/strategies/index";
import * as path from "path"; import * as path from "path";
...@@ -31,6 +32,7 @@ async function generateDatasource() { ...@@ -31,6 +32,7 @@ async function generateDatasource() {
}); });
await VectorStoreIndex.fromDocuments(documents, { await VectorStoreIndex.fromDocuments(documents, {
storageContext, storageContext,
docStoreStrategy: DocStoreStrategy.NONE,
}); });
}); });
console.log(`Storage successfully generated in ${ms / 1000}s.`); console.log(`Storage successfully generated in ${ms / 1000}s.`);
......
...@@ -28,6 +28,7 @@ export class OpenAIEmbedding implements BaseEmbedding { ...@@ -28,6 +28,7 @@ export class OpenAIEmbedding implements BaseEmbedding {
} }
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
nodes.forEach((node) => (node.embedding = [0]));
return nodes; return nodes;
} }
} }
import type { ImageType } from "../Node.js"; import {
import { BaseEmbedding } from "./types.js"; MetadataMode,
splitNodesByType,
type BaseNode,
type ImageType,
} from "../Node.js";
import { BaseEmbedding, batchEmbeddings } from "./types.js";
/* /*
* Base class for Multi Modal embeddings. * Base class for Multi Modal embeddings.
...@@ -8,9 +13,39 @@ import { BaseEmbedding } from "./types.js"; ...@@ -8,9 +13,39 @@ import { BaseEmbedding } from "./types.js";
export abstract class MultiModalEmbedding extends BaseEmbedding { export abstract class MultiModalEmbedding extends BaseEmbedding {
abstract getImageEmbedding(images: ImageType): Promise<number[]>; abstract getImageEmbedding(images: ImageType): Promise<number[]>;
/**
* Optionally override this method to retrieve multiple image embeddings in a single request
* @param images - the images to compute embeddings for
*/
async getImageEmbeddings(images: ImageType[]): Promise<number[][]> { async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
return Promise.all( return Promise.all(
images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)), images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
); );
} }
/**
 * Computes embeddings for the given nodes and stores them on each node's
 * `embedding` property.
 *
 * The nodes are split into text and image nodes. Text nodes are embedded
 * first, from their `MetadataMode.EMBED` content via `getTextEmbeddings`;
 * image nodes are embedded afterwards, from their `image` property via
 * `getImageEmbeddings`. Both passes are batched by `embedBatchSize`.
 *
 * @param nodes - nodes to embed; they are updated in place
 * @param _options - forwarded unchanged to `batchEmbeddings`
 * @returns the same `nodes` array that was passed in
 */
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
  const { imageNodes, textNodes } = splitNodesByType(nodes);

  // Text pass.
  const textContents = textNodes.map((textNode) =>
    textNode.getContent(MetadataMode.EMBED),
  );
  const textVectors = await batchEmbeddings(
    textContents,
    this.getTextEmbeddings.bind(this),
    this.embedBatchSize,
    _options,
  );
  textNodes.forEach((textNode, idx) => {
    textNode.embedding = textVectors[idx];
  });

  // Image pass.
  const images = imageNodes.map((imageNode) => imageNode.image);
  const imageVectors = await batchEmbeddings(
    images,
    this.getImageEmbeddings.bind(this),
    this.embedBatchSize,
    _options,
  );
  imageNodes.forEach((imageNode, idx) => {
    imageNode.embedding = imageVectors[idx];
  });

  return nodes;
}
} }
...@@ -5,6 +5,8 @@ import { SimilarityType, similarity } from "./utils.js"; ...@@ -5,6 +5,8 @@ import { SimilarityType, similarity } from "./utils.js";
const DEFAULT_EMBED_BATCH_SIZE = 10; const DEFAULT_EMBED_BATCH_SIZE = 10;
/** Signature of a batch embedding callback: takes an array of values and resolves to an array of number arrays. */
type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>;
export abstract class BaseEmbedding implements TransformComponent { export abstract class BaseEmbedding implements TransformComponent {
embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
...@@ -45,35 +47,18 @@ export abstract class BaseEmbedding implements TransformComponent { ...@@ -45,35 +47,18 @@ export abstract class BaseEmbedding implements TransformComponent {
logProgress?: boolean; logProgress?: boolean;
}, },
): Promise<Array<number[]>> { ): Promise<Array<number[]>> {
const resultEmbeddings: Array<number[]> = []; return await batchEmbeddings(
const chunkSize = this.embedBatchSize; texts,
this.getTextEmbeddings.bind(this),
const queue: string[] = texts; this.embedBatchSize,
options,
const curBatch: string[] = []; );
for (let i = 0; i < queue.length; i++) {
curBatch.push(queue[i]);
if (i == queue.length - 1 || curBatch.length == chunkSize) {
const embeddings = await this.getTextEmbeddings(curBatch);
resultEmbeddings.push(...embeddings);
if (options?.logProgress) {
console.log(`getting embedding progress: ${i} / ${queue.length}`);
}
curBatch.length = 0;
}
}
return resultEmbeddings;
} }
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
const embeddings = await this.getTextEmbeddingsBatch(texts); const embeddings = await this.getTextEmbeddingsBatch(texts, _options);
for (let i = 0; i < nodes.length; i++) { for (let i = 0; i < nodes.length; i++) {
nodes[i].embedding = embeddings[i]; nodes[i].embedding = embeddings[i];
...@@ -82,3 +67,35 @@ export abstract class BaseEmbedding implements TransformComponent { ...@@ -82,3 +67,35 @@ export abstract class BaseEmbedding implements TransformComponent {
return nodes; return nodes;
} }
} }
/**
 * Computes embeddings for `values` by calling `embedFunc` on consecutive
 * batches of at most `chunkSize` values, one batch at a time, and
 * concatenating the results in input order.
 *
 * @param values - the values to embed; not modified
 * @param embedFunc - callback that embeds one batch of values
 * @param chunkSize - maximum number of values per `embedFunc` call; if it is
 *   not a positive integer, all values are sent in a single batch
 * @param options - optional settings
 * @param options.logProgress - when true, logs the number of processed values
 *   to the console after each batch
 * @returns the embeddings returned by `embedFunc`, concatenated across batches
 */
export async function batchEmbeddings<T>(
  values: T[],
  embedFunc: EmbedFunc<T>,
  chunkSize: number,
  options?: {
    logProgress?: boolean;
  },
): Promise<Array<number[]>> {
  const resultEmbeddings: Array<number[]> = [];

  // Fall back to one batch holding everything when the chunk size cannot be
  // used as a step; this also prevents a non-advancing loop.
  const batchSize =
    Number.isInteger(chunkSize) && chunkSize > 0 ? chunkSize : values.length;

  for (let start = 0; start < values.length; start += batchSize) {
    // Each call gets its own array, so a callback that keeps a reference to
    // its batch never sees it emptied or refilled later.
    const batch = values.slice(start, start + batchSize);
    const embeddings = await embedFunc(batch);
    resultEmbeddings.push(...embeddings);

    if (options?.logProgress) {
      // Report the count of processed values so the last line reads "n / n".
      const processed = start + batch.length;
      console.log(`getting embedding progress: ${processed} / ${values.length}`);
    }
  }

  return resultEmbeddings;
}
...@@ -4,12 +4,7 @@ import type { ...@@ -4,12 +4,7 @@ import type {
Metadata, Metadata,
NodeWithScore, NodeWithScore,
} from "../../Node.js"; } from "../../Node.js";
import { import { ImageNode, ObjectType, splitNodesByType } from "../../Node.js";
ImageNode,
MetadataMode,
ObjectType,
splitNodesByType,
} from "../../Node.js";
import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
import type { ServiceContext } from "../../ServiceContext.js"; import type { ServiceContext } from "../../ServiceContext.js";
import { import {
...@@ -179,14 +174,21 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -179,14 +174,21 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
nodes: BaseNode[], nodes: BaseNode[],
options?: { logProgress?: boolean }, options?: { logProgress?: boolean },
): Promise<BaseNode[]> { ): Promise<BaseNode[]> {
const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); const { imageNodes, textNodes } = splitNodesByType(nodes);
const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, { if (imageNodes.length > 0) {
if (!this.imageEmbedModel) {
throw new Error(
"Cannot calculate image nodes embedding without 'imageEmbedModel' set",
);
}
await this.imageEmbedModel.transform(imageNodes, {
logProgress: options?.logProgress,
});
}
await this.embedModel.transform(textNodes, {
logProgress: options?.logProgress, logProgress: options?.logProgress,
}); });
return nodes.map((node, i) => { return nodes;
node.embedding = embeddings[i];
return node;
});
} }
/** /**
...@@ -324,25 +326,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -324,25 +326,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
if (!nodes || nodes.length === 0) { if (!nodes || nodes.length === 0) {
return; return;
} }
nodes = await this.getNodeEmbeddingResults(nodes, options);
const { imageNodes, textNodes } = splitNodesByType(nodes); const { imageNodes, textNodes } = splitNodesByType(nodes);
if (imageNodes.length > 0) { if (imageNodes.length > 0) {
if (!this.imageVectorStore) { if (!this.imageVectorStore) {
throw new Error("Cannot insert image nodes without image vector store"); throw new Error("Cannot insert image nodes without image vector store");
} }
const imageNodesWithEmbedding = await this.getImageNodeEmbeddingResults( await this.insertNodesToStore(this.imageVectorStore, imageNodes);
imageNodes,
options,
);
await this.insertNodesToStore(
this.imageVectorStore,
imageNodesWithEmbedding,
);
} }
const embeddingResults = await this.getNodeEmbeddingResults( await this.insertNodesToStore(this.vectorStore, textNodes);
textNodes,
options,
);
await this.insertNodesToStore(this.vectorStore, embeddingResults);
await this.indexStore.addIndexStruct(this.indexStruct); await this.indexStore.addIndexStruct(this.indexStruct);
} }
...@@ -378,35 +370,6 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -378,35 +370,6 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
await this.indexStore.addIndexStruct(this.indexStruct); await this.indexStore.addIndexStruct(this.indexStruct);
} }
} }
/**
 * Computes an embedding for each image node, one node at a time, and stores
 * it on the node's `embedding` property.
 *
 * @param nodes - the image nodes to embed; they are updated in place
 * @param options - optional settings
 * @param options.logProgress - when true, logs a progress line to the console
 *   before each node is embedded
 * @returns a new array holding the same node objects in input order, or an
 *   empty array when no `imageEmbedModel` is set
 */
async getImageNodeEmbeddingResults(
  nodes: ImageNode[],
  options?: { logProgress?: boolean },
): Promise<ImageNode[]> {
  // Without an image embedding model nothing is embedded and no nodes are returned.
  if (!this.imageEmbedModel) {
    return [];
  }

  const embedded: ImageNode[] = [];
  for (const [index, imageNode] of nodes.entries()) {
    if (options?.logProgress) {
      console.log(`Getting embedding for node ${index + 1}/${nodes.length}`);
    }
    imageNode.embedding = await this.imageEmbedModel.getImageEmbedding(
      imageNode.image,
    );
    embedded.push(imageNode);
  }
  return embedded;
}
} }
/** /**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment