diff --git a/.changeset/honest-readers-tease.md b/.changeset/honest-readers-tease.md new file mode 100644 index 0000000000000000000000000000000000000000..1f6aa0066aa4e455b7cdcb788b886595578c3422 --- /dev/null +++ b/.changeset/honest-readers-tease.md @@ -0,0 +1,7 @@ +--- +"@llamaindex/core": patch +"llamaindex": patch +"@llamaindex/core-e2e": patch +--- + +feat: support transform component callable diff --git a/packages/core/src/embeddings/base.ts b/packages/core/src/embeddings/base.ts index a27c5a04e3910816eabfd756aea729545d97d844..7ce0645fcc43faf914420505ac287cc04eef9021 100644 --- a/packages/core/src/embeddings/base.ts +++ b/packages/core/src/embeddings/base.ts @@ -1,7 +1,6 @@ import { type Tokenizers } from "@llamaindex/env"; import type { MessageContentDetail } from "../llms"; -import type { TransformComponent } from "../schema"; -import { BaseNode, MetadataMode } from "../schema"; +import { BaseNode, MetadataMode, TransformComponent } from "../schema"; import { extractSingleText } from "../utils"; import { truncateMaxTokens } from "./tokenizer.js"; import { SimilarityType, similarity } from "./utils.js"; @@ -20,10 +19,29 @@ export type BaseEmbeddingOptions = { logProgress?: boolean; }; -export abstract class BaseEmbedding implements TransformComponent { +export abstract class BaseEmbedding extends TransformComponent { embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; embedInfo?: EmbeddingInfo; + constructor() { + super( + async ( + nodes: BaseNode[], + options?: BaseEmbeddingOptions, + ): Promise<BaseNode[]> => { + const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); + + const embeddings = await this.getTextEmbeddingsBatch(texts, options); + + for (let i = 0; i < nodes.length; i++) { + nodes[i].embedding = embeddings[i]; + } + + return nodes; + }, + ); + } + similarity( embedding1: number[], embedding2: number[], @@ -76,21 +94,6 @@ export abstract class BaseEmbedding implements TransformComponent { ); } - async transform( - nodes: BaseNode[], - options?: BaseEmbeddingOptions, - ): Promise<BaseNode[]> { - const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); - - const embeddings = await this.getTextEmbeddingsBatch(texts, options); - - for (let i = 0; i < nodes.length; i++) { - nodes[i].embedding = embeddings[i]; - } - - return nodes; - } - truncateMaxTokens(input: string[]): string[] { return input.map((s) => { // truncate to max tokens diff --git a/packages/core/src/node-parser/base.ts b/packages/core/src/node-parser/base.ts index 9aeb61f1a6797d68b03a4d07e543ad32e05791b7..ccea8b3a8809a742e967c8f9614ced1b755354de 100644 --- a/packages/core/src/node-parser/base.ts +++ b/packages/core/src/node-parser/base.ts @@ -5,13 +5,19 @@ import { MetadataMode, NodeRelationship, TextNode, - type TransformComponent, + TransformComponent, } from "../schema"; -export abstract class NodeParser implements TransformComponent { +export abstract class NodeParser extends TransformComponent { includeMetadata: boolean = true; includePrevNextRel: boolean = true; + constructor() { + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + return this.getNodesFromDocuments(nodes as TextNode[]); + }); + } + protected postProcessParsedNodes( nodes: TextNode[], parentDocMap: Map<string, TextNode>, @@ -90,10 +96,6 @@ export abstract class NodeParser implements TransformComponent { return nodes; } - - async transform(nodes: BaseNode[], options?: {}): Promise<BaseNode[]> { - return this.getNodesFromDocuments(nodes as TextNode[]); - } } export abstract class TextSplitter extends NodeParser { diff --git a/packages/core/src/schema/index.ts b/packages/core/src/schema/index.ts index b1f8919781b030a1294981c7459727b1921c79e2..da1924211d5cbf5c408883a528821668a71a2d2d 100644 --- a/packages/core/src/schema/index.ts +++ b/packages/core/src/schema/index.ts @@ -1,4 +1,4 @@ export * from "./node"; -export type { TransformComponent } from "./type"; +export { TransformComponent } from "./type"; export { EngineResponse } from "./type/engine–response"; export * from "./zod"; diff --git a/packages/core/src/schema/type.ts b/packages/core/src/schema/type.ts index 688c5ce88dba0e05b54595b37ab84f661501e843..7aa3add8a803f21f3dbe0c616ef14b62d2e99554 100644 --- a/packages/core/src/schema/type.ts +++ b/packages/core/src/schema/type.ts @@ -1,8 +1,30 @@ +import { randomUUID } from "@llamaindex/env"; import type { BaseNode } from "./node"; -export interface TransformComponent { - transform<Options extends Record<string, unknown>>( +interface TransformComponentSignature { + <Options extends Record<string, unknown>>( nodes: BaseNode[], options?: Options, ): Promise<BaseNode[]>; } + +export interface TransformComponent extends TransformComponentSignature { + id: string; +} + +export class TransformComponent { + constructor(transformFn: TransformComponentSignature) { + Object.defineProperties( + transformFn, + Object.getOwnPropertyDescriptors(this.constructor.prototype), + ); + const transform = function transform( + ...args: Parameters<TransformComponentSignature> + ) { + return transformFn(...args); + }; + Reflect.setPrototypeOf(transform, new.target.prototype); + transform.id = randomUUID(); + return transform; + } +} diff --git a/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts b/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts index 2efe159afba5b407eaf072c96d30f4f1304921d6..c93aefd4e459f45267cbbb15d77782a5edca1b13 100644 --- a/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts +++ b/packages/llamaindex/e2e/fixtures/embeddings/OpenAIEmbedding.ts @@ -1,15 +1,26 @@ +import { TransformComponent } from "@llamaindex/core/schema"; import { + BaseEmbedding, BaseNode, SimilarityType, - type BaseEmbedding, type EmbeddingInfo, type MessageContentDetail, } from "llamaindex"; -export class OpenAIEmbedding implements BaseEmbedding { +export class OpenAIEmbedding + extends TransformComponent + implements BaseEmbedding +{ embedInfo?: EmbeddingInfo | undefined; embedBatchSize = 512; + constructor() { + super(async (nodes: BaseNode[], _options?: any): Promise<BaseNode[]> => { + nodes.forEach((node) => (node.embedding = [0])); + return nodes; + }); + } + async getQueryEmbedding(query: MessageContentDetail) { return [0]; } @@ -34,11 +45,6 @@ export class OpenAIEmbedding implements BaseEmbedding { return 1; } - async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { - nodes.forEach((node) => (node.embedding = [0])); - return nodes; - } - truncateMaxTokens(input: string[]): string[] { return input; } diff --git a/packages/llamaindex/src/extractors/types.ts b/packages/llamaindex/src/extractors/types.ts index d0af55f88d013bb5e4c1bfb0b61879ea26efbbb5..0bf97b3505a4bb95b46331848986fb5b885b81b3 100644 --- a/packages/llamaindex/src/extractors/types.ts +++ b/packages/llamaindex/src/extractors/types.ts @@ -1,11 +1,15 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; -import { MetadataMode, TextNode } from "@llamaindex/core/schema"; +import { + BaseNode, + MetadataMode, + TextNode, + TransformComponent, +} from "@llamaindex/core/schema"; import { defaultNodeTextTemplate } from "./prompts.js"; /* * Abstract class for all extractors. */ -export abstract class BaseExtractor implements TransformComponent { +export abstract class BaseExtractor extends TransformComponent { isTextNodeOnly: boolean = true; showProgress: boolean = true; metadataMode: MetadataMode = MetadataMode.ALL; @@ -13,16 +17,18 @@ export abstract class BaseExtractor implements TransformComponent { inPlace: boolean = true; numWorkers: number = 4; - abstract extract(nodes: BaseNode[]): Promise<Record<string, any>[]>; - - async transform(nodes: BaseNode[], options?: any): Promise<BaseNode[]> { - return this.processNodes( - nodes, - options?.excludedEmbedMetadataKeys, - options?.excludedLlmMetadataKeys, - ); + constructor() { + super(async (nodes: BaseNode[], options?: any): Promise<BaseNode[]> => { + return this.processNodes( + nodes, + options?.excludedEmbedMetadataKeys, + options?.excludedLlmMetadataKeys, + ); + }); } + abstract extract(nodes: BaseNode[]): Promise<Record<string, any>[]>; + /** * * @param nodes Nodes to extract metadata from. diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts index 639556eebfe6e928674aea9b48d043c62c8197c2..c4902da9976099a9d27ee102ff999cd9283c25d6 100644 --- a/packages/llamaindex/src/indices/vectorStore/index.ts +++ b/packages/llamaindex/src/indices/vectorStore/index.ts @@ -172,7 +172,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { const embedModel = this.embedModel ?? this.vectorStores[type as ModalityType]?.embedModel; if (embedModel && nodes) { - await embedModel.transform(nodes, { + await embedModel(nodes, { logProgress: options?.logProgress, }); } diff --git a/packages/llamaindex/src/ingestion/IngestionCache.ts b/packages/llamaindex/src/ingestion/IngestionCache.ts index 7dea514c23e5b48c1eef6329a75a601dde202ade..353e565f85a6533984400e0138c20a4e403c5594 100644 --- a/packages/llamaindex/src/ingestion/IngestionCache.ts +++ b/packages/llamaindex/src/ingestion/IngestionCache.ts @@ -35,7 +35,7 @@ export function getTransformationHash( const transformString: string = transformToJSON(transform); const hash = createSHA256(); - hash.update(nodesStr + transformString); + hash.update(nodesStr + transformString + transform.id); return hash.digest(); } diff --git a/packages/llamaindex/src/ingestion/IngestionPipeline.ts b/packages/llamaindex/src/ingestion/IngestionPipeline.ts index c7174aa17809cd4aeb29a2062c7b189d98c1bbaf..fed97e9925e026caf2749e4fe2fdcf3f11c55f77 100644 --- a/packages/llamaindex/src/ingestion/IngestionPipeline.ts +++ b/packages/llamaindex/src/ingestion/IngestionPipeline.ts @@ -40,7 +40,7 @@ export async function runTransformations( nodes = [...nodesToRun]; } if (docStoreStrategy) { - nodes = await docStoreStrategy.transform(nodes); + nodes = await docStoreStrategy(nodes); } for (const transform of transformations) { if (cache) { @@ -49,11 +49,11 @@ export async function runTransformations( if (cachedNodes) { nodes = cachedNodes; } else { - nodes = await transform.transform(nodes, transformOptions); + nodes = await transform(nodes, transformOptions); await cache.put(hash, nodes); } } else { - nodes = await transform.transform(nodes, transformOptions); + nodes = await transform(nodes, transformOptions); } } return nodes; diff --git a/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts b/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts index 5755370bb6c365c2892a4c4add015016c72e6ce3..c06451484b6457012fb4bb88e498cc9c536fa851 100644 --- a/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/DuplicatesStrategy.ts @@ -1,31 +1,30 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; +import { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; /** * Handle doc store duplicates by checking all hashes. */ -export class DuplicatesStrategy implements TransformComponent { +export class DuplicatesStrategy extends TransformComponent { private docStore: BaseDocumentStore; constructor(docStore: BaseDocumentStore) { - this.docStore = docStore; - } + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + const hashes = await this.docStore.getAllDocumentHashes(); + const currentHashes = new Set<string>(); + const nodesToRun: BaseNode[] = []; - async transform(nodes: BaseNode[]): Promise<BaseNode[]> { - const hashes = await this.docStore.getAllDocumentHashes(); - const currentHashes = new Set<string>(); - const nodesToRun: BaseNode[] = []; - - for (const node of nodes) { - if (!(node.hash in hashes) && !currentHashes.has(node.hash)) { - await this.docStore.setDocumentHash(node.id_, node.hash); - nodesToRun.push(node); - currentHashes.add(node.hash); + for (const node of nodes) { + if (!(node.hash in hashes) && !currentHashes.has(node.hash)) { + await this.docStore.setDocumentHash(node.id_, node.hash); + nodesToRun.push(node); + currentHashes.add(node.hash); + } } - } - await this.docStore.addDocuments(nodesToRun, true); + await this.docStore.addDocuments(nodesToRun, true); - return nodesToRun; + return nodesToRun; + }); + this.docStore = docStore; } } diff --git a/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts b/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts index 561c522a0d33240bc5c88bd5582c08274e6a3bc4..93c6aa49c26570c41d0733618e67b148fdb42126 100644 --- a/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/UpsertsAndDeleteStrategy.ts @@ -1,4 +1,4 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; +import { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; import { classify } from "./classify.js"; @@ -7,43 +7,42 @@ import { classify } from "./classify.js"; * Handle docstore upserts by checking hashes and ids. * Identify missing docs and delete them from docstore and vector store */ -export class UpsertsAndDeleteStrategy implements TransformComponent { +export class UpsertsAndDeleteStrategy extends TransformComponent { protected docStore: BaseDocumentStore; protected vectorStores?: VectorStore[]; constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { - this.docStore = docStore; - this.vectorStores = vectorStores; - } + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + const { dedupedNodes, missingDocs, unusedDocs } = await classify( + this.docStore, + nodes, + ); - async transform(nodes: BaseNode[]): Promise<BaseNode[]> { - const { dedupedNodes, missingDocs, unusedDocs } = await classify( - this.docStore, - nodes, - ); - - // remove unused docs - for (const refDocId of unusedDocs) { - await this.docStore.deleteRefDoc(refDocId, false); - if (this.vectorStores) { - for (const vectorStore of this.vectorStores) { - await vectorStore.delete(refDocId); + // remove unused docs + for (const refDocId of unusedDocs) { + await this.docStore.deleteRefDoc(refDocId, false); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(refDocId); + } } } - } - // remove missing docs - for (const docId of missingDocs) { - await this.docStore.deleteDocument(docId, true); - if (this.vectorStores) { - for (const vectorStore of this.vectorStores) { - await vectorStore.delete(docId); + // remove missing docs + for (const docId of missingDocs) { + await this.docStore.deleteDocument(docId, true); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(docId); + } } } - } - await this.docStore.addDocuments(dedupedNodes, true); + await this.docStore.addDocuments(dedupedNodes, true); - return dedupedNodes; + return dedupedNodes; + }); + this.docStore = docStore; + this.vectorStores = vectorStores; } } diff --git a/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts b/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts index 69ae36beb50902c4d9d6cf6b9582c32b4b8715df..efeae560f8f743431aba8acfa4333a51a0a9583c 100644 --- a/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts +++ b/packages/llamaindex/src/ingestion/strategies/UpsertsStrategy.ts @@ -1,4 +1,4 @@ -import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; +import { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; import { classify } from "./classify.js"; @@ -6,28 +6,27 @@ import { classify } from "./classify.js"; /** * Handles doc store upserts by checking hashes and ids. */ -export class UpsertsStrategy implements TransformComponent { +export class UpsertsStrategy extends TransformComponent { protected docStore: BaseDocumentStore; protected vectorStores?: VectorStore[]; constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { - this.docStore = docStore; - this.vectorStores = vectorStores; - } - - async transform(nodes: BaseNode[]): Promise<BaseNode[]> { - const { dedupedNodes, unusedDocs } = await classify(this.docStore, nodes); - // remove unused docs - for (const refDocId of unusedDocs) { - await this.docStore.deleteRefDoc(refDocId, false); - if (this.vectorStores) { - for (const vectorStore of this.vectorStores) { - await vectorStore.delete(refDocId); + super(async (nodes: BaseNode[]): Promise<BaseNode[]> => { + const { dedupedNodes, unusedDocs } = await classify(this.docStore, nodes); + // remove unused docs + for (const refDocId of unusedDocs) { + await this.docStore.deleteRefDoc(refDocId, false); + if (this.vectorStores) { + for (const vectorStore of this.vectorStores) { + await vectorStore.delete(refDocId); + } } } - } - // add non-duplicate docs - await this.docStore.addDocuments(dedupedNodes, true); - return dedupedNodes; + // add non-duplicate docs + await this.docStore.addDocuments(dedupedNodes, true); + return dedupedNodes; + }); + this.docStore = docStore; + this.vectorStores = vectorStores; } } diff --git a/packages/llamaindex/src/ingestion/strategies/index.ts b/packages/llamaindex/src/ingestion/strategies/index.ts index 96765a758294b26068856f3987dbe453277f5e41..00cafe8fc104e9d3b4e8bcdb111b1ef2a8cd6d73 100644 --- a/packages/llamaindex/src/ingestion/strategies/index.ts +++ b/packages/llamaindex/src/ingestion/strategies/index.ts @@ -1,4 +1,4 @@ -import type { TransformComponent } from "@llamaindex/core/schema"; +import { TransformComponent } from "@llamaindex/core/schema"; import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js"; import { DuplicatesStrategy } from "./DuplicatesStrategy.js"; @@ -19,9 +19,9 @@ export enum DocStoreStrategy { NONE = "none", // no-op strategy } -class NoOpStrategy implements TransformComponent { - async transform(nodes: any[]): Promise<any[]> { - return nodes; +class NoOpStrategy extends TransformComponent { + constructor() { + super(async (nodes) => nodes); } }