From 91d02a4fc0c08944dfffbcd7740dde41aea2ab32 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Thu, 25 Jul 2024 11:19:57 -0700 Subject: [PATCH] feat: support transform component callable (#1072) --- .changeset/honest-readers-tease.md | 7 +++ packages/core/src/embeddings/base.ts | 39 +++++++------- packages/core/src/node-parser/base.ts | 14 ++--- packages/core/src/schema/index.ts | 2 +- packages/core/src/schema/type.ts | 26 ++++++++- .../fixtures/embeddings/OpenAIEmbedding.ts | 20 ++++--- packages/llamaindex/src/extractors/types.ts | 28 ++++++---- .../src/indices/vectorStore/index.ts | 2 +- .../src/ingestion/IngestionCache.ts | 2 +- .../src/ingestion/IngestionPipeline.ts | 6 +-- .../strategies/DuplicatesStrategy.ts | 33 ++++++------ .../strategies/UpsertsAndDeleteStrategy.ts | 53 +++++++++---------- .../ingestion/strategies/UpsertsStrategy.ts | 35 ++++++------ .../src/ingestion/strategies/index.ts | 8 +-- 14 files changed, 159 insertions(+), 116 deletions(-) create mode 100644 .changeset/honest-readers-tease.md diff --git a/.changeset/honest-readers-tease.md b/.changeset/honest-readers-tease.md new file mode 100644 index 000000000..1f6aa0066 --- /dev/null +++ b/.changeset/honest-readers-tease.md @@ -0,0 +1,7 @@ +--- +"@llamaindex/core": patch +"llamaindex": patch +"@llamaindex/core-e2e": patch +--- + +feat: support transform component callable diff --git a/packages/core/src/embeddings/base.ts b/packages/core/src/embeddings/base.ts index a27c5a04e..7ce0645fc 100644 --- a/packages/core/src/embeddings/base.ts +++ b/packages/core/src/embeddings/base.ts @@ -1,7 +1,6 @@ import { type Tokenizers } from "@llamaindex/env"; import type { MessageContentDetail } from "../llms"; -import type { TransformComponent } from "../schema"; -import { BaseNode, MetadataMode } from "../schema"; +import { BaseNode, MetadataMode, TransformComponent } from "../schema"; import { extractSingleText } from "../utils"; import { truncateMaxTokens } from "./tokenizer.js"; import { SimilarityType, similarity } from "./utils.js"; @@ -20,10 +19,29 @@ export type BaseEmbeddingOptions = { logProgress?: boolean; }; -export abstract class BaseEmbedding implements TransformComponent { +export abstract class BaseEmbedding extends TransformComponent { embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; embedInfo?: EmbeddingInfo; + constructor() { + super( + async ( + nodes: BaseNode[], + options?: BaseEmbeddingOptions, + ): Promise<BaseNode[]> => { + const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); + + const embeddings = await this.getTextEmbeddingsBatch(texts, options); + + for (let i = 0; i < nodes.length; i++) { + nodes[i].embedding = embeddings[i]; + } + + return nodes; + }, + ); + } + similarity( embedding1: number[], embedding2: number[], @@ -76,21 +94,6 @@ export abstract class BaseEmbedding implements TransformComponent { ); } - async transform( - nodes: BaseNode[], - options?: BaseEmbeddingOptions, - ): Promise<BaseNode[]> { - const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); - - const embeddings = await this.getTextEmbeddingsBatch(texts, options); - - for (let i = 0; i < nodes.length; i++) { - nodes[i].embedding = embeddings[i]; - } - - return nodes; - } - truncateMaxTokens(input: string[]): string[] { return input.map((s) => { // truncate to max tokens diff --git a/packages/core/src/node-parser/base.ts b/packages/core/src/node-parser/base.ts index 9aeb61f1a..ccea8b3a8 100644 --- a/packages/core/src/node-parser/base.ts +++ b/packages/core/src/node-parser/base.ts @@ -5,13 +5,19 @@ import { MetadataMode, NodeRelationship, TextNode, - type TransformComponent, + TransformComponent, } from "../schema"; -export abstract class NodeParser implements TransformComponent { +export abstract class NodeParser extends TransformComponent { includeMetadata: boolean = true; includePrevNextRel: boolean = true; + constructor() { + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + return this.getNodesFromDocuments(nodes as TextNode[]); + }); + } + protected postProcessParsedNodes( nodes: TextNode[], parentDocMap: Map<string, TextNode>, @@ -90,10 +96,6 @@ export abstract class NodeParser implements TransformComponent { return nodes; } - - async transform(nodes: BaseNode[], options?: {}): Promise<BaseNode[]> { - return this.getNodesFromDocuments(nodes as TextNode[]); - } } export abstract class TextSplitter extends NodeParser { diff --git a/packages/core/src/schema/index.ts b/packages/core/src/schema/index.ts index b1f891978..da1924211 100644 --- a/packages/core/src/schema/index.ts +++ b/packages/core/src/schema/index.ts @@ -1,4 +1,4 @@ export * from "./node"; -export type { TransformComponent } from "./type"; +export { TransformComponent } from "./type"; export { EngineResponse } from "./type/engine–response"; export * from "./zod"; diff --git a/packages/core/src/schema/type.ts b/packages/core/src/schema/type.ts index 688c5ce88..7aa3add8a 100644 --- a/packages/core/src/schema/type.ts +++ b/packages/core/src/schema/type.ts @@ -1,8 +1,30 @@ +import { randomUUID } from "@llamaindex/env"; import type { BaseNode } from "./node"; -export interface TransformComponent { - transform<Options extends Record<string, unknown>>( +interface TransformComponentSignature { + <Options extends Record<string, unknown>>( nodes: BaseNode[], options?: Options, ): Promise<BaseNode[]>; } + +export interface TransformComponent extends TransformComponentSignature { + id: string; +} + +export class TransformComponent { + constructor(transformFn: TransformComponentSignature) { + Object.defineProperties( + transformFn, + Object.getOwnPropertyDescriptors(this.constructor.prototype), + ); + const transform = function transform( + ...args: Parameters<TransformComponentSignature> + ) { + return transformFn(...args); + }; + Reflect.setPrototypeOf(transform, new.target.prototype); + transform.id = randomUUID(); + return transform; + } +} diff --git a/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts b/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts index 2efe159af..c93aefd4e 100644 --- a/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts +++ b/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts @@ -1,15 +1,26 @@ +import { TransformComponent } from "@llamaindex/core/schema"; import { + BaseEmbedding, BaseNode, SimilarityType, - type BaseEmbedding, type EmbeddingInfo, type MessageContentDetail, } from "llamaindex"; -export class OpenAIEmbedding implements BaseEmbedding { +export class OpenAIEmbedding + extends TransformComponent + implements BaseEmbedding +{ embedInfo?: EmbeddingInfo | undefined; embedBatchSize = 512; + constructor() { + super(async (nodes: BaseNode[], _options?: any): Promise<BaseNode[]> => { + nodes.forEach((node) => (node.embedding = [0])); + return nodes; + }); + } + async getQueryEmbedding(query: MessageContentDetail) { return [0]; } @@ -34,11 +45,6 @@ export class OpenAIEmbedding implements BaseEmbedding { return 1; } - async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { - nodes.forEach((node) => (node.embedding = [0])); - return nodes; - } - truncateMaxTokens(input: string[]): string[] { return input; } diff --git a/packages/llamaindex/src/extractors/types.ts b/packages/llamaindex/src/extractors/types.ts index d0af55f88..0bf97b350 100644 --- a/packages/llamaindex/src/extractors/types.ts +++ b/packages/llamaindex/src/extractors/types.ts @@ -1,11 +1,15 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; -import { MetadataMode, TextNode } from "@llamaindex/core/schema"; +import { + BaseNode, + MetadataMode, + TextNode, + TransformComponent, +} from "@llamaindex/core/schema"; import { defaultNodeTextTemplate } from "./prompts.js"; /* * Abstract class for all extractors. */ -export abstract class BaseExtractor implements TransformComponent { +export abstract class BaseExtractor extends TransformComponent { isTextNodeOnly: boolean = true; showProgress: boolean = true; metadataMode: MetadataMode = MetadataMode.ALL; @@ -13,16 +17,18 @@ export abstract class BaseExtractor implements TransformComponent { inPlace: boolean = true; numWorkers: number = 4; - abstract extract(nodes: BaseNode[]): Promise<Record<string, any>[]>; - - async transform(nodes: BaseNode[], options?: any): Promise<BaseNode[]> { - return this.processNodes( - nodes, - options?.excludedEmbedMetadataKeys, - options?.excludedLlmMetadataKeys, - ); + constructor() { + super(async (nodes: BaseNode[], options?: any): Promise<BaseNode[]> => { + return this.processNodes( + nodes, + options?.excludedEmbedMetadataKeys, + options?.excludedLlmMetadataKeys, + ); + }); } + abstract extract(nodes: BaseNode[]): Promise<Record<string, any>[]>; + /** * * @param nodes Nodes to extract metadata from. diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts index 639556eeb..c4902da99 100644 --- a/packages/llamaindex/src/indices/vectorStore/index.ts +++ b/packages/llamaindex/src/indices/vectorStore/index.ts @@ -172,7 +172,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { const embedModel = this.embedModel ?? this.vectorStores[type as ModalityType]?.embedModel; if (embedModel && nodes) { - await embedModel.transform(nodes, { + await embedModel(nodes, { logProgress: options?.logProgress, }); } diff --git a/packages/llamaindex/src/ingestion/IngestionCache.ts b/packages/llamaindex/src/ingestion/IngestionCache.ts index 7dea514c2..353e565f8 100644 --- a/packages/llamaindex/src/ingestion/IngestionCache.ts +++ b/packages/llamaindex/src/ingestion/IngestionCache.ts @@ -35,7 +35,7 @@ export function getTransformationHash( const transformString: string = transformToJSON(transform); const hash = createSHA256(); - hash.update(nodesStr + transformString); + hash.update(nodesStr + transformString + transform.id); return hash.digest(); } diff --git a/packages/llamaindex/src/ingestion/IngestionPipeline.ts b/packages/llamaindex/src/ingestion/IngestionPipeline.ts index c7174aa17..fed97e992 100644 --- a/packages/llamaindex/src/ingestion/IngestionPipeline.ts +++ b/packages/llamaindex/src/ingestion/IngestionPipeline.ts @@ -40,7 +40,7 @@ export async function runTransformations( nodes = [...nodesToRun]; } if (docStoreStrategy) { - nodes = await docStoreStrategy.transform(nodes); + nodes = await docStoreStrategy(nodes); } for (const transform of transformations) { if (cache) { @@ -49,11 +49,11 @@ export async function runTransformations( if (cachedNodes) { nodes = cachedNodes; } else { - nodes = await transform.transform(nodes, transformOptions); + nodes = await transform(nodes, transformOptions); await cache.put(hash, nodes); } } else { - nodes = await transform.transform(nodes, transformOptions); + nodes = await transform(nodes, transformOptions); } } return nodes; diff --git a/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts b/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts index 5755370bb..c06451484 100644 --- a/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts @@ -1,31 +1,30 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; +import { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; /** * Handle doc store duplicates by checking all hashes. */ -export class DuplicatesStrategy implements TransformComponent { +export class DuplicatesStrategy extends TransformComponent { private docStore: BaseDocumentStore; constructor(docStore: BaseDocumentStore) { - this.docStore = docStore; - } + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + const hashes = await this.docStore.getAllDocumentHashes(); + const currentHashes = new Set<string>(); + const nodesToRun: BaseNode[] = []; - async transform(nodes: BaseNode[]): Promise<BaseNode[]> { - const hashes = await this.docStore.getAllDocumentHashes(); - const currentHashes = new Set<string>(); - const nodesToRun: BaseNode[] = []; - - for (const node of nodes) { - if (!(node.hash in hashes) && !currentHashes.has(node.hash)) { - await this.docStore.setDocumentHash(node.id_, node.hash); - nodesToRun.push(node); - currentHashes.add(node.hash); + for (const node of nodes) { + if (!(node.hash in hashes) && !currentHashes.has(node.hash)) { + await this.docStore.setDocumentHash(node.id_, node.hash); + nodesToRun.push(node); + currentHashes.add(node.hash); + } } - } - await this.docStore.addDocuments(nodesToRun, true); + await this.docStore.addDocuments(nodesToRun, true); - return nodesToRun; + return nodesToRun; + }); + this.docStore = docStore; } } diff --git a/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts b/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts index 561c522a0..93c6aa49c 100644 --- a/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts @@ -1,4 +1,4 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; +import { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; import { classify } from "./classify.js"; @@ -7,43 +7,42 @@ import { classify } from "./classify.js"; * Handle docstore upserts by checking hashes and ids. * Identify missing docs and delete them from docstore and vector store */ -export class UpsertsAndDeleteStrategy implements TransformComponent { +export class UpsertsAndDeleteStrategy extends TransformComponent { protected docStore: BaseDocumentStore; protected vectorStores?: VectorStore[]; constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { - this.docStore = docStore; - this.vectorStores = vectorStores; - } + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + const { dedupedNodes, missingDocs, unusedDocs } = await classify( + this.docStore, + nodes, + ); - async transform(nodes: BaseNode[]): Promise<BaseNode[]> { - const { dedupedNodes, missingDocs, unusedDocs } = await classify( - this.docStore, - nodes, - ); - - // remove unused docs - for (const refDocId of unusedDocs) { - await this.docStore.deleteRefDoc(refDocId, false); - if (this.vectorStores) { - for (const vectorStore of this.vectorStores) { - await vectorStore.delete(refDocId); + // remove unused docs + for (const refDocId of unusedDocs) { + await this.docStore.deleteRefDoc(refDocId, false); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(refDocId); + } } } - } - // remove missing docs - for (const docId of missingDocs) { - await this.docStore.deleteDocument(docId, true); - if (this.vectorStores) { - for (const vectorStore of this.vectorStores) { - await vectorStore.delete(docId); + // remove missing docs + for (const docId of missingDocs) { + await this.docStore.deleteDocument(docId, true); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(docId); + } } } - } - await this.docStore.addDocuments(dedupedNodes, true); + await this.docStore.addDocuments(dedupedNodes, true); - return dedupedNodes; + return dedupedNodes; + }); + this.docStore = docStore; + this.vectorStores = vectorStores; } } diff --git a/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts b/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts index 69ae36beb..efeae560f 100644 --- a/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts @@ -1,4 +1,4 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; +import { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; import { classify } from "./classify.js"; @@ -6,28 +6,27 @@ import { classify } from "./classify.js"; /** * Handles doc store upserts by checking hashes and ids. */ -export class UpsertsStrategy implements TransformComponent { +export class UpsertsStrategy extends TransformComponent { protected docStore: BaseDocumentStore; protected vectorStores?: VectorStore[]; constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { - this.docStore = docStore; - this.vectorStores = vectorStores; - } - - async transform(nodes: BaseNode[]): Promise<BaseNode[]> { - const { dedupedNodes, unusedDocs } = await classify(this.docStore, nodes); - // remove unused docs - for (const refDocId of unusedDocs) { - await this.docStore.deleteRefDoc(refDocId, false); - if (this.vectorStores) { - for (const vectorStore of this.vectorStores) { - await vectorStore.delete(refDocId); + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + const { dedupedNodes, unusedDocs } = await classify(this.docStore, nodes); + // remove unused docs + for (const refDocId of unusedDocs) { + await this.docStore.deleteRefDoc(refDocId, false); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(refDocId); + } } } - } - // add non-duplicate docs - await this.docStore.addDocuments(dedupedNodes, true); - return dedupedNodes; + // add non-duplicate docs + await this.docStore.addDocuments(dedupedNodes, true); + return dedupedNodes; + }); + this.docStore = docStore; + this.vectorStores = vectorStores; } } diff --git a/packages/llamaindex/src/ingestion/strategies/index.ts b/packages/llamaindex/src/ingestion/strategies/index.ts index 96765a758..00cafe8fc 100644 --- a/packages/llamaindex/src/ingestion/strategies/index.ts +++ b/packages/llamaindex/src/ingestion/strategies/index.ts @@ -1,4 +1,4 @@ -import type { TransformComponent } from "@llamaindex/core/schema"; +import { TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; import { DuplicatesStrategy } from "./DuplicatesStrategy.js"; @@ -19,9 +19,9 @@ export enum DocStoreStrategy { NONE = "none", // no-op strategy } -class NoOpStrategy implements TransformComponent { - async transform(nodes: any[]): Promise<any[]> { - return nodes; +class NoOpStrategy extends TransformComponent { + constructor() { + super(async (nodes) => nodes); } } -- GitLab