diff --git a/examples/pipeline/ingestion.ts b/examples/pipeline/ingestion.ts
index 7446b08142295317993b5ef6ca44b159833062a7..341bfea778f6521a3e481a8eb23365167f7d4400 100644
--- a/examples/pipeline/ingestion.ts
+++ b/examples/pipeline/ingestion.ts
@@ -19,7 +19,6 @@ async function main() {
   const pipeline = new IngestionPipeline({
     transformations: [
       new SimpleNodeParser({ chunkSize: 1024, chunkOverlap: 20 }),
-      // new TitleExtractor(llm),
       new OpenAIEmbedding(),
     ],
   });
diff --git a/packages/core/src/ingestion/IngestionCache.ts b/packages/core/src/ingestion/IngestionCache.ts
new file mode 100644
index 0000000000000000000000000000000000000000..e88906f8b8f89ca163887931e9d475653cccfcf2
--- /dev/null
+++ b/packages/core/src/ingestion/IngestionCache.ts
@@ -0,0 +1,47 @@
+import { BaseNode, MetadataMode } from "../Node";
+import { createSHA256 } from "../env";
+import { BaseKVStore, SimpleKVStore } from "../storage";
+import { docToJson, jsonToDoc } from "../storage/docStore/utils";
+import { TransformComponent } from "./types";
+
+export function getTransformationHash(
+  nodes: BaseNode[],
+  transform: TransformComponent,
+) {
+  const nodesStr: string = nodes
+    .map((node) => node.getContent(MetadataMode.ALL))
+    .join("");
+
+  const transformString: string = JSON.stringify(transform);
+  const hash = createSHA256();
+  hash.update(nodesStr + transformString);
+  return hash.digest();
+}
+
+export class IngestionCache {
+  collection: string = "llama_cache";
+  cache: BaseKVStore;
+  nodesKey = "nodes";
+
+  constructor(collection?: string) {
+    if (collection) {
+      this.collection = collection;
+    }
+    this.cache = new SimpleKVStore();
+  }
+
+  async put(hash: string, nodes: BaseNode[]) {
+    const val = {
+      [this.nodesKey]: nodes.map((node) => docToJson(node)),
+    };
+    await this.cache.put(hash, val, this.collection);
+  }
+
+  async get(hash: string): Promise<BaseNode[] | undefined> {
+    const json = await this.cache.get(hash, this.collection);
+    if (!json || !json[this.nodesKey] || !Array.isArray(json[this.nodesKey])) {
+      return undefined;
+    }
+    return json[this.nodesKey].map((doc: any) => jsonToDoc(doc));
+  }
+}
diff --git a/packages/core/src/ingestion/IngestionPipeline.ts b/packages/core/src/ingestion/IngestionPipeline.ts
index 88068f0f07c5853975cf2d7b40fff1f3f0dd5ac8..377b7b67018b100c4a15666cd3394299002d9a91 100644
--- a/packages/core/src/ingestion/IngestionPipeline.ts
+++ b/packages/core/src/ingestion/IngestionPipeline.ts
@@ -1,32 +1,47 @@
 import { BaseNode, Document } from "../Node";
 import { BaseReader } from "../readers/base";
 import { BaseDocumentStore, VectorStore } from "../storage";
+import { IngestionCache, getTransformationHash } from "./IngestionCache";
 import { DocStoreStrategy, createDocStoreStrategy } from "./strategies";
 import { TransformComponent } from "./types";
 
-interface IngestionRunArgs {
+type IngestionRunArgs = {
   documents?: Document[];
   nodes?: BaseNode[];
+};
+
+type TransformRunArgs = {
   inPlace?: boolean;
-}
+  cache?: IngestionCache;
+};
 
 export async function runTransformations(
   nodesToRun: BaseNode[],
   transformations: TransformComponent[],
   transformOptions: any = {},
-  { inPlace = true }: IngestionRunArgs,
+  { inPlace = true, cache }: TransformRunArgs,
 ): Promise<BaseNode[]> {
   let nodes = nodesToRun;
   if (!inPlace) {
     nodes = [...nodesToRun];
   }
   for (const transform of transformations) {
-    nodes = await transform.transform(nodes, transformOptions);
+    if (cache) {
+      const hash = getTransformationHash(nodes, transform);
+      const cachedNodes = await cache.get(hash);
+      if (cachedNodes) {
+        nodes = cachedNodes;
+      } else {
+        nodes = await transform.transform(nodes, transformOptions);
+        await cache.put(hash, nodes);
+      }
+    } else {
+      nodes = await transform.transform(nodes, transformOptions);
+    }
   }
   return nodes;
 }
 
-// TODO: add caching, add concurrency
 export class IngestionPipeline {
   transformations: TransformComponent[] = [];
   documents?: Document[];
@@ -34,7 +49,8 @@ export class IngestionPipeline {
   vectorStore?: VectorStore;
   docStore?: BaseDocumentStore;
   docStoreStrategy: DocStoreStrategy = DocStoreStrategy.UPSERTS;
-  disableCache: boolean = true;
+  cache?: IngestionCache;
+  disableCache: boolean = false;
 
   private _docStoreStrategy?: TransformComponent;
 
@@ -45,6 +61,9 @@ export class IngestionPipeline {
       this.docStore,
       this.vectorStore,
     );
+    if (!this.disableCache) {
+      this.cache = new IngestionCache();
+    }
   }
 
   async prepareInput(
@@ -68,9 +87,10 @@
   }
 
   async run(
-    args: IngestionRunArgs = {},
+    args: IngestionRunArgs & TransformRunArgs = {},
     transformOptions?: any,
   ): Promise<BaseNode[]> {
+    args.cache = args.cache ?? this.cache;
     const inputNodes = await this.prepareInput(args.documents, args.nodes);
     let nodesToRun;
     if (this._docStoreStrategy) {
diff --git a/packages/core/src/tests/ingestion/IngestionCache.test.ts b/packages/core/src/tests/ingestion/IngestionCache.test.ts
new file mode 100644
index 0000000000000000000000000000000000000000..52a27801eeb23cb79a43fb91e7aaf06b3d5bf591
--- /dev/null
+++ b/packages/core/src/tests/ingestion/IngestionCache.test.ts
@@ -0,0 +1,74 @@
+import { BaseNode, TextNode } from "../../Node";
+import { TransformComponent } from "../../ingestion";
+import {
+  IngestionCache,
+  getTransformationHash,
+} from "../../ingestion/IngestionCache";
+import { SimpleNodeParser } from "../../nodeParsers";
+
+describe("IngestionCache", () => {
+  let cache: IngestionCache;
+  const hash = "1";
+
+  beforeAll(() => {
+    cache = new IngestionCache();
+  });
+  test("should put and get", async () => {
+    const nodes = [new TextNode({ text: "some text", id_: "some id" })];
+    await cache.put(hash, nodes);
+    const result = await cache.get(hash);
+    expect(result).toEqual(nodes);
+  });
+  test("should return undefined if not found", async () => {
+    const result = await cache.get("not found");
+    expect(result).toBeUndefined();
+  });
+});
+
+describe("getTransformationHash", () => {
+  let nodes: BaseNode[], transform: TransformComponent;
+
+  beforeAll(() => {
+    nodes = [new TextNode({ text: "some text", id_: "some id" })];
+    transform = new SimpleNodeParser({
+      chunkOverlap: 10,
+      chunkSize: 1024,
+    });
+  });
+  test("should return a hash", () => {
+    const result = getTransformationHash(nodes, transform);
+    expect(typeof result).toBe("string");
+  });
+  test("should return the same hash for the same inputs", () => {
+    const result1 = getTransformationHash(nodes, transform);
+    const result2 = getTransformationHash(nodes, transform);
+    expect(result1).toBe(result2);
+  });
+  test("should return the same hash for other instances with same inputs", () => {
+    const result1 = getTransformationHash(
+      [new TextNode({ text: "some text", id_: "some id" })],
+      transform,
+    );
+    const result2 = getTransformationHash(nodes, transform);
+    expect(result1).toBe(result2);
+  });
+  test("should return different hashes for different nodes", () => {
+    const result1 = getTransformationHash(nodes, transform);
+    const result2 = getTransformationHash(
+      [new TextNode({ text: "some other text", id_: "some id" })],
+      transform,
+    );
+    expect(result1).not.toBe(result2);
+  });
+  test("should return different hashes for different transforms", () => {
+    const result1 = getTransformationHash(nodes, transform);
+    const result2 = getTransformationHash(
+      nodes,
+      new SimpleNodeParser({
+        chunkOverlap: 10,
+        chunkSize: 512,
+      }),
+    );
+    expect(result1).not.toBe(result2);
+  });
+});