Skip to content
Snippets Groups Projects
Unverified commit 91d02a4f, authored by Alex Yang, committed by GitHub
Browse files

feat: support transform component callable (#1072)

parent 086b9403
No related branches found
No related tags found
No related merge requests found
Showing
with 159 additions and 116 deletions
---
"@llamaindex/core": patch
"llamaindex": patch
"@llamaindex/core-e2e": patch
---
feat: support transform component callable
import { type Tokenizers } from "@llamaindex/env"; import { type Tokenizers } from "@llamaindex/env";
import type { MessageContentDetail } from "../llms"; import type { MessageContentDetail } from "../llms";
import type { TransformComponent } from "../schema"; import { BaseNode, MetadataMode, TransformComponent } from "../schema";
import { BaseNode, MetadataMode } from "../schema";
import { extractSingleText } from "../utils"; import { extractSingleText } from "../utils";
import { truncateMaxTokens } from "./tokenizer.js"; import { truncateMaxTokens } from "./tokenizer.js";
import { SimilarityType, similarity } from "./utils.js"; import { SimilarityType, similarity } from "./utils.js";
...@@ -20,10 +19,29 @@ export type BaseEmbeddingOptions = { ...@@ -20,10 +19,29 @@ export type BaseEmbeddingOptions = {
logProgress?: boolean; logProgress?: boolean;
}; };
export abstract class BaseEmbedding implements TransformComponent { export abstract class BaseEmbedding extends TransformComponent {
embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
embedInfo?: EmbeddingInfo; embedInfo?: EmbeddingInfo;
/**
 * Hands the embedding step to the TransformComponent base as its callable.
 *
 * The callable reads each node's content in `MetadataMode.EMBED`, requests
 * embeddings for all of them through `getTextEmbeddingsBatch`, stores the
 * embedding at the matching position on each node, and resolves with the
 * same `nodes` array it was given.
 */
constructor() {
  super(
    async (
      nodes: BaseNode[],
      options?: BaseEmbeddingOptions,
    ): Promise<BaseNode[]> => {
      const contents = nodes.map((item) => item.getContent(MetadataMode.EMBED));
      const vectors = await this.getTextEmbeddingsBatch(contents, options);
      // Pair every node with the vector found at the same index.
      for (const [position, item] of nodes.entries()) {
        item.embedding = vectors[position];
      }
      return nodes;
    },
  );
}
similarity( similarity(
embedding1: number[], embedding1: number[],
embedding2: number[], embedding2: number[],
...@@ -76,21 +94,6 @@ export abstract class BaseEmbedding implements TransformComponent { ...@@ -76,21 +94,6 @@ export abstract class BaseEmbedding implements TransformComponent {
); );
} }
/**
 * Embeds the given nodes in place.
 *
 * @param nodes Nodes whose `MetadataMode.EMBED` content is embedded.
 * @param options Forwarded unchanged to `getTextEmbeddingsBatch`.
 * @returns The same `nodes` array, with `embedding` set on each entry.
 */
async transform(
  nodes: BaseNode[],
  options?: BaseEmbeddingOptions,
): Promise<BaseNode[]> {
  const contents = nodes.map((item) => item.getContent(MetadataMode.EMBED));
  const vectors = await this.getTextEmbeddingsBatch(contents, options);
  // Pair every node with the vector found at the same index.
  for (const [position, item] of nodes.entries()) {
    item.embedding = vectors[position];
  }
  return nodes;
}
truncateMaxTokens(input: string[]): string[] { truncateMaxTokens(input: string[]): string[] {
return input.map((s) => { return input.map((s) => {
// truncate to max tokens // truncate to max tokens
......
...@@ -5,13 +5,19 @@ import { ...@@ -5,13 +5,19 @@ import {
MetadataMode, MetadataMode,
NodeRelationship, NodeRelationship,
TextNode, TextNode,
type TransformComponent, TransformComponent,
} from "../schema"; } from "../schema";
export abstract class NodeParser implements TransformComponent { export abstract class NodeParser extends TransformComponent {
includeMetadata: boolean = true; includeMetadata: boolean = true;
includePrevNextRel: boolean = true; includePrevNextRel: boolean = true;
/**
 * Hands the parsing step to the TransformComponent base as its callable:
 * the incoming nodes are treated as `TextNode[]` and passed to
 * `getNodesFromDocuments`, whose result is the transform output.
 */
constructor() {
  super(
    async (nodes: BaseNode[]): Promise<BaseNode[]> =>
      this.getNodesFromDocuments(nodes as TextNode[]),
  );
}
protected postProcessParsedNodes( protected postProcessParsedNodes(
nodes: TextNode[], nodes: TextNode[],
parentDocMap: Map<string, TextNode>, parentDocMap: Map<string, TextNode>,
...@@ -90,10 +96,6 @@ export abstract class NodeParser implements TransformComponent { ...@@ -90,10 +96,6 @@ export abstract class NodeParser implements TransformComponent {
return nodes; return nodes;
} }
/**
 * Runs the parser as a transform step.
 *
 * @param nodes Input nodes; they are treated as `TextNode[]`.
 * @param options Accepted for signature compatibility; not read here.
 * @returns Whatever `getNodesFromDocuments` produces for the input.
 */
async transform(nodes: BaseNode[], options?: {}): Promise<BaseNode[]> {
  const documents = nodes as TextNode[];
  return this.getNodesFromDocuments(documents);
}
} }
export abstract class TextSplitter extends NodeParser { export abstract class TextSplitter extends NodeParser {
......
export * from "./node"; export * from "./node";
export type { TransformComponent } from "./type"; export { TransformComponent } from "./type";
export { EngineResponse } from "./type/engine–response"; export { EngineResponse } from "./type/engine–response";
export * from "./zod"; export * from "./zod";
import { randomUUID } from "@llamaindex/env";
import type { BaseNode } from "./node"; import type { BaseNode } from "./node";
export interface TransformComponent { interface TransformComponentSignature {
transform<Options extends Record<string, unknown>>( <Options extends Record<string, unknown>>(
nodes: BaseNode[], nodes: BaseNode[],
options?: Options, options?: Options,
): Promise<BaseNode[]>; ): Promise<BaseNode[]>;
} }
/**
 * Instance shape of a transform component: a value that is callable with the
 * `TransformComponentSignature` call signature and also carries an `id`.
 *
 * This interface shares its name with the `TransformComponent` class, so the
 * two declarations merge and class instances are typed as callable.
 */
export interface TransformComponent extends TransformComponentSignature {
// Identifier of this component instance.
id: string;
}
/**
 * Base class whose instances are themselves callable functions.
 *
 * The constructor does not return the usual `this` object. It builds a
 * wrapper function around `transformFn` and returns that instead, so the
 * result of `new` (and the `this` a subclass sees after `super(...)`) is the
 * wrapper function.
 */
export class TransformComponent {
/**
 * @param transformFn Function invoked with the caller's arguments whenever
 * the constructed component is called.
 */
constructor(transformFn: TransformComponentSignature) {
// Copy the own property descriptors found on `this.constructor.prototype`
// onto the supplied function itself.
Object.defineProperties(
transformFn,
Object.getOwnPropertyDescriptors(this.constructor.prototype),
);
// Wrapper that forwards every argument to `transformFn` and returns its
// result; this wrapper is what the constructor hands back.
const transform = function transform(
...args: Parameters<TransformComponentSignature>
) {
return transformFn(...args);
};
// Re-parent the wrapper onto the prototype of the class being constructed
// (`new.target`), so members declared on that class resolve on the wrapper.
// NOTE(review): this replaces the wrapper's default prototype; helpers such
// as `call`/`bind` are presumably only reachable if that chain includes
// `Function.prototype` — confirm.
Reflect.setPrototypeOf(transform, new.target.prototype);
// Give each constructed component a freshly generated id.
transform.id = randomUUID();
// Returning an object from a constructor overrides the default `this`.
return transform;
}
}
import { TransformComponent } from "@llamaindex/core/schema";
import { import {
BaseEmbedding,
BaseNode, BaseNode,
SimilarityType, SimilarityType,
type BaseEmbedding,
type EmbeddingInfo, type EmbeddingInfo,
type MessageContentDetail, type MessageContentDetail,
} from "llamaindex"; } from "llamaindex";
export class OpenAIEmbedding implements BaseEmbedding { export class OpenAIEmbedding
extends TransformComponent
implements BaseEmbedding
{
embedInfo?: EmbeddingInfo | undefined; embedInfo?: EmbeddingInfo | undefined;
embedBatchSize = 512; embedBatchSize = 512;
/**
 * Registers a stub transform with the TransformComponent base: every node
 * receives the fixed embedding `[0]` and the same array is returned.
 */
constructor() {
  super(async (nodes: BaseNode[], _options?: any): Promise<BaseNode[]> => {
    nodes.forEach((item) => {
      // Each node gets its own fresh single-element array.
      item.embedding = [0];
    });
    return nodes;
  });
}
async getQueryEmbedding(query: MessageContentDetail) { async getQueryEmbedding(query: MessageContentDetail) {
return [0]; return [0];
} }
...@@ -34,11 +45,6 @@ export class OpenAIEmbedding implements BaseEmbedding { ...@@ -34,11 +45,6 @@ export class OpenAIEmbedding implements BaseEmbedding {
return 1; return 1;
} }
/**
 * Stub transform: assigns the fixed embedding `[0]` to every node.
 *
 * @param nodes Nodes to mark as embedded.
 * @param _options Ignored.
 * @returns The same `nodes` array.
 */
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
  nodes.forEach((item) => {
    // Each node gets its own fresh single-element array.
    item.embedding = [0];
  });
  return nodes;
}
truncateMaxTokens(input: string[]): string[] { truncateMaxTokens(input: string[]): string[] {
return input; return input;
} }
......
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import {
import { MetadataMode, TextNode } from "@llamaindex/core/schema"; BaseNode,
MetadataMode,
TextNode,
TransformComponent,
} from "@llamaindex/core/schema";
import { defaultNodeTextTemplate } from "./prompts.js"; import { defaultNodeTextTemplate } from "./prompts.js";
/* /*
* Abstract class for all extractors. * Abstract class for all extractors.
*/ */
export abstract class BaseExtractor implements TransformComponent { export abstract class BaseExtractor extends TransformComponent {
isTextNodeOnly: boolean = true; isTextNodeOnly: boolean = true;
showProgress: boolean = true; showProgress: boolean = true;
metadataMode: MetadataMode = MetadataMode.ALL; metadataMode: MetadataMode = MetadataMode.ALL;
...@@ -13,16 +17,18 @@ export abstract class BaseExtractor implements TransformComponent { ...@@ -13,16 +17,18 @@ export abstract class BaseExtractor implements TransformComponent {
inPlace: boolean = true; inPlace: boolean = true;
numWorkers: number = 4; numWorkers: number = 4;
abstract extract(nodes: BaseNode[]): Promise<Record<string, any>[]>; constructor() {
super(async (nodes: BaseNode[], options?: any): Promise<BaseNode[]> => {
async transform(nodes: BaseNode[], options?: any): Promise<BaseNode[]> { return this.processNodes(
return this.processNodes( nodes,
nodes, options?.excludedEmbedMetadataKeys,
options?.excludedEmbedMetadataKeys, options?.excludedLlmMetadataKeys,
options?.excludedLlmMetadataKeys, );
); });
} }
abstract extract(nodes: BaseNode[]): Promise<Record<string, any>[]>;
/** /**
* *
* @param nodes Nodes to extract metadata from. * @param nodes Nodes to extract metadata from.
......
...@@ -172,7 +172,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -172,7 +172,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
const embedModel = const embedModel =
this.embedModel ?? this.vectorStores[type as ModalityType]?.embedModel; this.embedModel ?? this.vectorStores[type as ModalityType]?.embedModel;
if (embedModel && nodes) { if (embedModel && nodes) {
await embedModel.transform(nodes, { await embedModel(nodes, {
logProgress: options?.logProgress, logProgress: options?.logProgress,
}); });
} }
......
...@@ -35,7 +35,7 @@ export function getTransformationHash( ...@@ -35,7 +35,7 @@ export function getTransformationHash(
const transformString: string = transformToJSON(transform); const transformString: string = transformToJSON(transform);
const hash = createSHA256(); const hash = createSHA256();
hash.update(nodesStr + transformString); hash.update(nodesStr + transformString + transform.id);
return hash.digest(); return hash.digest();
} }
......
...@@ -40,7 +40,7 @@ export async function runTransformations( ...@@ -40,7 +40,7 @@ export async function runTransformations(
nodes = [...nodesToRun]; nodes = [...nodesToRun];
} }
if (docStoreStrategy) { if (docStoreStrategy) {
nodes = await docStoreStrategy.transform(nodes); nodes = await docStoreStrategy(nodes);
} }
for (const transform of transformations) { for (const transform of transformations) {
if (cache) { if (cache) {
...@@ -49,11 +49,11 @@ export async function runTransformations( ...@@ -49,11 +49,11 @@ export async function runTransformations(
if (cachedNodes) { if (cachedNodes) {
nodes = cachedNodes; nodes = cachedNodes;
} else { } else {
nodes = await transform.transform(nodes, transformOptions); nodes = await transform(nodes, transformOptions);
await cache.put(hash, nodes); await cache.put(hash, nodes);
} }
} else { } else {
nodes = await transform.transform(nodes, transformOptions); nodes = await transform(nodes, transformOptions);
} }
} }
return nodes; return nodes;
......
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { BaseDocumentStore } from "../../storage/docStore/types.js";
/** /**
* Handle doc store duplicates by checking all hashes. * Handle doc store duplicates by checking all hashes.
*/ */
export class DuplicatesStrategy implements TransformComponent { export class DuplicatesStrategy extends TransformComponent {
private docStore: BaseDocumentStore; private docStore: BaseDocumentStore;
constructor(docStore: BaseDocumentStore) { constructor(docStore: BaseDocumentStore) {
this.docStore = docStore; super(async (nodes: BaseNode[]): Promise<BaseNode[]> => {
} const hashes = await this.docStore.getAllDocumentHashes();
const currentHashes = new Set<string>();
const nodesToRun: BaseNode[] = [];
async transform(nodes: BaseNode[]): Promise<BaseNode[]> { for (const node of nodes) {
const hashes = await this.docStore.getAllDocumentHashes(); if (!(node.hash in hashes) && !currentHashes.has(node.hash)) {
const currentHashes = new Set<string>(); await this.docStore.setDocumentHash(node.id_, node.hash);
const nodesToRun: BaseNode[] = []; nodesToRun.push(node);
currentHashes.add(node.hash);
for (const node of nodes) { }
if (!(node.hash in hashes) && !currentHashes.has(node.hash)) {
await this.docStore.setDocumentHash(node.id_, node.hash);
nodesToRun.push(node);
currentHashes.add(node.hash);
} }
}
await this.docStore.addDocuments(nodesToRun, true); await this.docStore.addDocuments(nodesToRun, true);
return nodesToRun; return nodesToRun;
});
this.docStore = docStore;
} }
} }
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { VectorStore } from "../../storage/vectorStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js";
import { classify } from "./classify.js"; import { classify } from "./classify.js";
...@@ -7,43 +7,42 @@ import { classify } from "./classify.js"; ...@@ -7,43 +7,42 @@ import { classify } from "./classify.js";
* Handle docstore upserts by checking hashes and ids. * Handle docstore upserts by checking hashes and ids.
* Identify missing docs and delete them from docstore and vector store * Identify missing docs and delete them from docstore and vector store
*/ */
export class UpsertsAndDeleteStrategy implements TransformComponent { export class UpsertsAndDeleteStrategy extends TransformComponent {
protected docStore: BaseDocumentStore; protected docStore: BaseDocumentStore;
protected vectorStores?: VectorStore[]; protected vectorStores?: VectorStore[];
constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) {
this.docStore = docStore; super(async (nodes: BaseNode[]): Promise<BaseNode[]> => {
this.vectorStores = vectorStores; const { dedupedNodes, missingDocs, unusedDocs } = await classify(
} this.docStore,
nodes,
);
async transform(nodes: BaseNode[]): Promise<BaseNode[]> { // remove unused docs
const { dedupedNodes, missingDocs, unusedDocs } = await classify( for (const refDocId of unusedDocs) {
this.docStore, await this.docStore.deleteRefDoc(refDocId, false);
nodes, if (this.vectorStores) {
); for (const vectorStore of this.vectorStores) {
await vectorStore.delete(refDocId);
// remove unused docs }
for (const refDocId of unusedDocs) {
await this.docStore.deleteRefDoc(refDocId, false);
if (this.vectorStores) {
for (const vectorStore of this.vectorStores) {
await vectorStore.delete(refDocId);
} }
} }
}
// remove missing docs // remove missing docs
for (const docId of missingDocs) { for (const docId of missingDocs) {
await this.docStore.deleteDocument(docId, true); await this.docStore.deleteDocument(docId, true);
if (this.vectorStores) { if (this.vectorStores) {
for (const vectorStore of this.vectorStores) { for (const vectorStore of this.vectorStores) {
await vectorStore.delete(docId); await vectorStore.delete(docId);
}
} }
} }
}
await this.docStore.addDocuments(dedupedNodes, true); await this.docStore.addDocuments(dedupedNodes, true);
return dedupedNodes; return dedupedNodes;
});
this.docStore = docStore;
this.vectorStores = vectorStores;
} }
} }
import type { BaseNode, TransformComponent } from "@llamaindex/core/schema"; import { BaseNode, TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { VectorStore } from "../../storage/vectorStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js";
import { classify } from "./classify.js"; import { classify } from "./classify.js";
...@@ -6,28 +6,27 @@ import { classify } from "./classify.js"; ...@@ -6,28 +6,27 @@ import { classify } from "./classify.js";
/** /**
* Handles doc store upserts by checking hashes and ids. * Handles doc store upserts by checking hashes and ids.
*/ */
export class UpsertsStrategy implements TransformComponent { export class UpsertsStrategy extends TransformComponent {
protected docStore: BaseDocumentStore; protected docStore: BaseDocumentStore;
protected vectorStores?: VectorStore[]; protected vectorStores?: VectorStore[];
constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) { constructor(docStore: BaseDocumentStore, vectorStores?: VectorStore[]) {
this.docStore = docStore; super(async (nodes: BaseNode[]): Promise<BaseNode[]> => {
this.vectorStores = vectorStores; const { dedupedNodes, unusedDocs } = await classify(this.docStore, nodes);
} // remove unused docs
for (const refDocId of unusedDocs) {
async transform(nodes: BaseNode[]): Promise<BaseNode[]> { await this.docStore.deleteRefDoc(refDocId, false);
const { dedupedNodes, unusedDocs } = await classify(this.docStore, nodes); if (this.vectorStores) {
// remove unused docs for (const vectorStore of this.vectorStores) {
for (const refDocId of unusedDocs) { await vectorStore.delete(refDocId);
await this.docStore.deleteRefDoc(refDocId, false); }
if (this.vectorStores) {
for (const vectorStore of this.vectorStores) {
await vectorStore.delete(refDocId);
} }
} }
} // add non-duplicate docs
// add non-duplicate docs await this.docStore.addDocuments(dedupedNodes, true);
await this.docStore.addDocuments(dedupedNodes, true); return dedupedNodes;
return dedupedNodes; });
this.docStore = docStore;
this.vectorStores = vectorStores;
} }
} }
import type { TransformComponent } from "@llamaindex/core/schema"; import { TransformComponent } from "@llamaindex/core/schema";
import type { BaseDocumentStore } from "../../storage/docStore/types.js"; import type { BaseDocumentStore } from "../../storage/docStore/types.js";
import type { VectorStore } from "../../storage/vectorStore/types.js"; import type { VectorStore } from "../../storage/vectorStore/types.js";
import { DuplicatesStrategy } from "./DuplicatesStrategy.js"; import { DuplicatesStrategy } from "./DuplicatesStrategy.js";
...@@ -19,9 +19,9 @@ export enum DocStoreStrategy { ...@@ -19,9 +19,9 @@ export enum DocStoreStrategy {
NONE = "none", // no-op strategy NONE = "none", // no-op strategy
} }
class NoOpStrategy implements TransformComponent { class NoOpStrategy extends TransformComponent {
async transform(nodes: any[]): Promise<any[]> { constructor() {
return nodes; super(async (nodes) => nodes);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment