diff --git a/.changeset/selfish-kids-own.md b/.changeset/selfish-kids-own.md new file mode 100644 index 0000000000000000000000000000000000000000..f9fb2aa9a4608c27ba7f7c2654fb7cc665418243 --- /dev/null +++ b/.changeset/selfish-kids-own.md @@ -0,0 +1,7 @@ +--- +"@llamaindex/core": patch +"llamaindex": patch +"@llamaindex/core-tests": patch +--- + +feat(node-parser): support async function diff --git a/packages/core/src/node-parser/base.ts b/packages/core/src/node-parser/base.ts index dbf0c0bc5f398c34e27d830e41862c4f7bfea429..ec5631e038bcb6defd3f4741933193c6328e41c5 100644 --- a/packages/core/src/node-parser/base.ts +++ b/packages/core/src/node-parser/base.ts @@ -7,21 +7,27 @@ import { TextNode, TransformComponent, } from "../schema"; +import { isPromise } from "../utils"; -export abstract class NodeParser extends TransformComponent<BaseNode[]> { +export abstract class NodeParser< + Result extends TextNode[] | Promise<TextNode[]> = + | TextNode[] + | Promise<TextNode[]>, +> extends TransformComponent<Result> { includeMetadata: boolean = true; includePrevNextRel: boolean = true; constructor() { - super((nodes: BaseNode[]): BaseNode[] => { + super((nodes: BaseNode[]): Result => { + // alex: should we fix `as` type? return this.getNodesFromDocuments(nodes as TextNode[]); }); } protected postProcessParsedNodes( - nodes: TextNode[], + nodes: Awaited<Result>, parentDocMap: Map<string, TextNode>, - ): TextNode[] { + ): Awaited<Result> { nodes.forEach((node, i) => { const parentDoc = parentDocMap.get(node.sourceNode?.nodeId || ""); @@ -73,9 +79,9 @@ export abstract class NodeParser extends TransformComponent<BaseNode[]> { protected abstract parseNodes( documents: TextNode[], showProgress?: boolean, - ): TextNode[]; + ): Result; - public getNodesFromDocuments(documents: TextNode[]): TextNode[] { + public getNodesFromDocuments(documents: TextNode[]): Result { const docsId: Map<string, TextNode> = new Map( documents.map((doc) => [doc.id_, doc]), ); @@ -85,20 +91,36 @@ export abstract class NodeParser extends TransformComponent<BaseNode[]> { documents, }); - const nodes = this.postProcessParsedNodes( - this.parseNodes(documents), - docsId, - ); + const parsedNodes = this.parseNodes(documents); + if (isPromise(parsedNodes)) { + return parsedNodes.then((parsedNodes) => { + const nodes = this.postProcessParsedNodes( + parsedNodes as Awaited<Result>, + docsId, + ); - callbackManager.dispatchEvent("node-parsing-end", { - nodes, - }); + callbackManager.dispatchEvent("node-parsing-end", { + nodes, + }); - return nodes; + return nodes; + }) as Result; + } else { + const nodes = this.postProcessParsedNodes( + parsedNodes as Awaited<Result>, + docsId, + ); + + callbackManager.dispatchEvent("node-parsing-end", { + nodes, + }); + + return nodes; + } } } -export abstract class TextSplitter extends NodeParser { +export abstract class TextSplitter extends NodeParser<TextNode[]> { abstract splitText(text: string): string[]; public splitTexts(texts: string[]): string[] { diff --git a/packages/core/src/node-parser/markdown.ts b/packages/core/src/node-parser/markdown.ts index 12c08294438d98d6259c5af3ff4bb7eb1357e81d..e7c14f875248b1b9861da28fbeaebb879fa8fea5 100644 --- a/packages/core/src/node-parser/markdown.ts +++ b/packages/core/src/node-parser/markdown.ts @@ -6,7 +6,7 @@ import { } from "../schema"; import { NodeParser } from "./base"; -export class MarkdownNodeParser extends NodeParser { +export class MarkdownNodeParser extends NodeParser<TextNode[]> { override parseNodes(nodes: TextNode[], showProgress?: boolean): TextNode[] { return nodes.reduce<TextNode[]>((allNodes, node) => { const markdownNodes = this.getNodesFromNode(node); diff --git a/packages/core/src/node-parser/sentence-window.ts b/packages/core/src/node-parser/sentence-window.ts index be6219a774ab062ddfbd93de36e0d8abaed5ca17..161689371200ea9c1e0895111840dfb2ac9bafc4 100644 --- a/packages/core/src/node-parser/sentence-window.ts +++ b/packages/core/src/node-parser/sentence-window.ts @@ -9,7 +9,7 @@ import { import { NodeParser } from "./base"; import { splitBySentenceTokenizer, type TextSplitterFn } from "./utils"; -export class SentenceWindowNodeParser extends NodeParser { +export class SentenceWindowNodeParser extends NodeParser<TextNode[]> { static DEFAULT_WINDOW_SIZE = 3; static DEFAULT_WINDOW_METADATA_KEY = "window"; static DEFAULT_ORIGINAL_TEXT_METADATA_KEY = "originalText"; diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts index a682739194aab5988c04c16d55f228e823b219bc..920a6787a67e4de8f2698298c2d0b4ca6ac17c14 100644 --- a/packages/core/src/utils/index.ts +++ b/packages/core/src/utils/index.ts @@ -1,5 +1,9 @@ import type { JSONValue } from "../global"; +export const isPromise = <T>(obj: unknown): obj is Promise<T> => { + return obj != null && typeof obj === "object" && "then" in obj; +}; + export const isAsyncIterable = ( obj: unknown, ): obj is AsyncIterable<unknown> => { diff --git a/packages/core/tests/node-parser/node-parser.test.ts b/packages/core/tests/node-parser/node-parser.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..6357173ce880c978104159d0ff6899eee346a130 --- /dev/null +++ b/packages/core/tests/node-parser/node-parser.test.ts @@ -0,0 +1,24 @@ +import { NodeParser } from "@llamaindex/core/node-parser"; +import { TextNode } from "@llamaindex/core/schema"; +import { describe, expect, test } from "vitest"; + +describe("NodeParser", () => { + test("node parser should allow async parse function", async () => { + class MyNodeParser extends NodeParser<Promise<TextNode[]>> { + protected async parseNodes(documents: TextNode[]): Promise<TextNode[]> { + await new Promise((resolve) => setTimeout(resolve, 1000)); + return documents; + } + } + + const nodeParser = new MyNodeParser(); + const nodes = [ + new TextNode({ + text: "Hello, world!", + }), + ]; + const result = nodeParser(nodes); + expect(result).toBeInstanceOf(Promise); + await expect(result).resolves.toEqual(nodes); + }); +}); diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts index ca79b0376126fe485bbf952534eae8b2506c6825..827874901b759da012c5e67fdf86154489019917 100644 --- a/packages/llamaindex/src/indices/keyword/index.ts +++ b/packages/llamaindex/src/indices/keyword/index.ts @@ -296,7 +296,7 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> { await docStore.setDocumentHash(doc.id_, doc.hash); } - const nodes = Settings.nodeParser.getNodesFromDocuments(documents); + const nodes = await Settings.nodeParser.getNodesFromDocuments(documents); const index = await KeywordTableIndex.init({ nodes, storageContext, diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts index f1e955266c0fa45dd65693df2543978b34358e1b..a8cc772c59915b3bd2825e88c630d66edb91ff14 100644 --- a/packages/llamaindex/src/indices/summary/index.ts +++ b/packages/llamaindex/src/indices/summary/index.ts @@ -145,7 +145,7 @@ export class SummaryIndex extends BaseIndex<IndexList> { await docStore.setDocumentHash(doc.id_, doc.hash); } - const nodes = Settings.nodeParser.getNodesFromDocuments(documents); + const nodes = await Settings.nodeParser.getNodesFromDocuments(documents); const index = await SummaryIndex.init({ nodes,