diff --git a/packages/core/src/nodeParsers/SentenceWindowNodeParser.ts b/packages/core/src/nodeParsers/SentenceWindowNodeParser.ts new file mode 100644 index 0000000000000000000000000000000000000000..46794bd181367f9462a3cd75ec5a7de9874ba5d7 --- /dev/null +++ b/packages/core/src/nodeParsers/SentenceWindowNodeParser.ts @@ -0,0 +1,85 @@ +import { BaseNode } from "../Node"; +import { SentenceSplitter } from "../TextSplitter"; +import { NodeParser } from "./types"; +import { getNodesFromDocument } from "./utils"; + +export const DEFAULT_WINDOW_SIZE = 3; +export const DEFAULT_WINDOW_METADATA_KEY = "window"; +export const DEFAULT_OG_TEXT_METADATA_KEY = "original_text"; + +export class SentenceWindowNodeParser implements NodeParser { + /** + * The text splitter to use. + */ + textSplitter: SentenceSplitter; + /** + * The number of sentences on each side of a sentence to capture. + */ + windowSize: number = DEFAULT_WINDOW_SIZE; + /** + * The metadata key to store the sentence window under. + */ + windowMetadataKey: string = DEFAULT_WINDOW_METADATA_KEY; + /** + * The metadata key to store the original sentence in. + */ + originalTextMetadataKey: string = DEFAULT_OG_TEXT_METADATA_KEY; + /** + * Whether to include metadata in the nodes. + */ + includeMetadata: boolean = true; + /** + * Whether to include previous and next relationships in the nodes. + */ + includePrevNextRel: boolean = true; + + constructor(init?: Partial<SentenceWindowNodeParser>) { + Object.assign(this, init); + this.textSplitter = init?.textSplitter ?? new SentenceSplitter(); + } + + static fromDefaults( + init?: Partial<SentenceWindowNodeParser>, + ): SentenceWindowNodeParser { + return new SentenceWindowNodeParser(init); + } + + getNodesFromDocuments(documents: BaseNode[]) { + return documents + .map((document) => this.buildWindowNodesFromDocument(document)) + .flat(); + } + + protected buildWindowNodesFromDocument(doc: BaseNode): BaseNode[] { + const nodes = getNodesFromDocument( + doc, + this.textSplitter.getSentenceSplits.bind(this.textSplitter), + this.includeMetadata, + this.includePrevNextRel, + ); + + for (let i = 0; i < nodes.length; i++) { + const node = nodes[i]; + const windowNodes = nodes.slice( + Math.max(0, i - this.windowSize), + Math.min(i + this.windowSize + 1, nodes.length), + ); + + node.metadata[this.windowMetadataKey] = windowNodes + .map((n) => n.getText()) + .join(" "); + node.metadata[this.originalTextMetadataKey] = node.getText(); + + node.excludedEmbedMetadataKeys.push( + this.windowMetadataKey, + this.originalTextMetadataKey, + ); + node.excludedLlmMetadataKeys.push( + this.windowMetadataKey, + this.originalTextMetadataKey, + ); + } + + return nodes; + } +} diff --git a/packages/core/src/nodeParsers/SimpleNodeParser.ts b/packages/core/src/nodeParsers/SimpleNodeParser.ts index 8a1738f571ceae6ce652cd23d31e262407b3ea63..7a95000b82c7be9819564c5ec998e6d07b553f13 100644 --- a/packages/core/src/nodeParsers/SimpleNodeParser.ts +++ b/packages/core/src/nodeParsers/SimpleNodeParser.ts @@ -1,77 +1,8 @@ -import { - BaseNode, - Document, - ImageDocument, - NodeRelationship, - TextNode, -} from "../Node"; +import { BaseNode } from "../Node"; import { SentenceSplitter } from "../TextSplitter"; import { DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE } from "../constants"; import { NodeParser } from "./types"; - -/** - * Splits the text of a document into smaller parts. - * @param document - The document to split. - * @param textSplitter - The text splitter to use. - * @returns An array of text splits. - */ -function getTextSplitsFromDocument( - document: Document, - textSplitter: SentenceSplitter, -) { - const text = document.getText(); - return textSplitter.splitText(text); -} - -/** - * Generates an array of nodes from a document. - * @param doc - * @param textSplitter - The text splitter to use. - * @param includeMetadata - Whether to include metadata in the nodes. - * @param includePrevNextRel - Whether to include previous and next relationships in the nodes. - * @returns An array of nodes. - */ -function getNodesFromDocument( - doc: BaseNode, - textSplitter: SentenceSplitter, - includeMetadata: boolean = true, - includePrevNextRel: boolean = true, -) { - if (doc instanceof ImageDocument) { - return [doc]; - } - if (!(doc instanceof Document)) { - throw new Error("Expected either an Image Document or Document"); - } - const document = doc as Document; - const nodes: TextNode[] = []; - - const textSplits = getTextSplitsFromDocument(document, textSplitter); - - textSplits.forEach((textSplit) => { - const node = new TextNode({ - text: textSplit, - metadata: includeMetadata ? document.metadata : {}, - }); - node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo(); - nodes.push(node); - }); - - if (includePrevNextRel) { - nodes.forEach((node, index) => { - if (index > 0) { - node.relationships[NodeRelationship.PREVIOUS] = - nodes[index - 1].asRelatedNodeInfo(); - } - if (index < nodes.length - 1) { - node.relationships[NodeRelationship.NEXT] = - nodes[index + 1].asRelatedNodeInfo(); - } - }); - } - - return nodes; -} +import { getNodesFromDocument } from "./utils"; /** * SimpleNodeParser is the default NodeParser. It splits documents into TextNodes using a splitter, by default SentenceSplitter @@ -94,7 +25,6 @@ export class SimpleNodeParser implements NodeParser { textSplitter?: SentenceSplitter; includeMetadata?: boolean; includePrevNextRel?: boolean; - chunkSize?: number; chunkOverlap?: number; }) { @@ -123,7 +53,14 @@ export class SimpleNodeParser implements NodeParser { */ getNodesFromDocuments(documents: BaseNode[]) { return documents - .map((document) => getNodesFromDocument(document, this.textSplitter)) + .map((document) => + getNodesFromDocument( + document, + this.textSplitter.splitText.bind(this.textSplitter), + this.includeMetadata, + this.includePrevNextRel, + ), + ) .flat(); } } diff --git a/packages/core/src/nodeParsers/index.ts b/packages/core/src/nodeParsers/index.ts index 4094a976711a71da70cbd7cac74ae897c3b14893..0507f22c1dbc8c9294194e97d8849cf4514fa129 100644 --- a/packages/core/src/nodeParsers/index.ts +++ b/packages/core/src/nodeParsers/index.ts @@ -1,2 +1,3 @@ +export * from "./SentenceWindowNodeParser"; export * from "./SimpleNodeParser"; export * from "./types"; diff --git a/packages/core/src/nodeParsers/utils.ts b/packages/core/src/nodeParsers/utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..96765288aef81fabb6900b77fa5196e52f0bf6b7 --- /dev/null +++ b/packages/core/src/nodeParsers/utils.ts @@ -0,0 +1,75 @@ +import _ from "lodash"; +import { + BaseNode, + Document, + ImageDocument, + NodeRelationship, + TextNode, +} from "../Node"; + +type TextSplitter = (s: string) => string[]; + +/** + * Splits the text of a document into smaller parts. + * @param document - The document to split. + * @param textSplitter - The text splitter to use. + * @returns An array of text splits. + */ +function getTextSplitsFromDocument( + document: Document, + textSplitter: TextSplitter, +) { + const text = document.getText(); + return textSplitter(text); +} + +/** + * Generates an array of nodes from a document. + * @param doc + * @param textSplitter - The text splitter to use. + * @param includeMetadata - Whether to include metadata in the nodes. + * @param includePrevNextRel - Whether to include previous and next relationships in the nodes. + * @returns An array of nodes. + */ +export function getNodesFromDocument( + doc: BaseNode, + textSplitter: TextSplitter, + includeMetadata: boolean = true, + includePrevNextRel: boolean = true, +): TextNode[] { + if (doc instanceof ImageDocument) { + // TODO: use text splitter on text of image documents + return [doc]; + } + if (!(doc instanceof Document)) { + throw new Error("Expected either an Image Document or Document"); + } + const document = doc as Document; + const nodes: TextNode[] = []; + + const textSplits = getTextSplitsFromDocument(document, textSplitter); + + textSplits.forEach((textSplit) => { + const node = new TextNode({ + text: textSplit, + metadata: includeMetadata ? _.cloneDeep(document.metadata) : {}, + }); + node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo(); + nodes.push(node); + }); + + if (includePrevNextRel) { + nodes.forEach((node, index) => { + if (index > 0) { + node.relationships[NodeRelationship.PREVIOUS] = + nodes[index - 1].asRelatedNodeInfo(); + } + if (index < nodes.length - 1) { + node.relationships[NodeRelationship.NEXT] = + nodes[index + 1].asRelatedNodeInfo(); + } + }); + } + + return nodes; +} diff --git a/packages/core/src/tests/nodeParsers/SentenceWindowNodeParser.test.ts b/packages/core/src/tests/nodeParsers/SentenceWindowNodeParser.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..143a01932752c1d126a1187a80c37159689baaa2 --- /dev/null +++ b/packages/core/src/tests/nodeParsers/SentenceWindowNodeParser.test.ts @@ -0,0 +1,34 @@ +import { Document, MetadataMode } from "../../Node"; +import { + DEFAULT_WINDOW_METADATA_KEY, + SentenceWindowNodeParser, +} from "../../nodeParsers"; + +describe("Tests for the SentenceWindowNodeParser class", () => { + test("testing the constructor", () => { + const sentenceWindowNodeParser = new SentenceWindowNodeParser(); + expect(sentenceWindowNodeParser).toBeDefined(); + }); + test("testing the getNodesFromDocuments method", () => { + const sentenceWindowNodeParser = SentenceWindowNodeParser.fromDefaults({ + windowSize: 1, + }); + const doc = new Document({ text: "Hello. Cat Mouse. Dog." }); + const resultingNodes = sentenceWindowNodeParser.getNodesFromDocuments([ + doc, + ]); + expect(resultingNodes.length).toEqual(3); + expect(resultingNodes.map((n) => n.getContent(MetadataMode.NONE))).toEqual([ + "Hello.", + "Cat Mouse.", + "Dog.", + ]); + expect( + resultingNodes.map((n) => n.metadata[DEFAULT_WINDOW_METADATA_KEY]), + ).toEqual([ + "Hello. Cat Mouse.", + "Hello. Cat Mouse. Dog.", + "Cat Mouse. Dog.", + ]); + }); +});