diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index 43344097b747dc08d44c572d83cb6f7230caf626..1908e67248a6945774b2e0728583ecd3ac676f8b 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -23,19 +23,23 @@ export enum MetadataMode { NONE = "NONE", } -export interface RelatedNodeInfo { +export type Metadata = Record<string, any>; + +export interface RelatedNodeInfo<T extends Metadata = Metadata> { nodeId: string; nodeType?: ObjectType; - metadata: Record<string, any>; + metadata: T; hash?: string; } -export type RelatedNodeType = RelatedNodeInfo | RelatedNodeInfo[]; +export type RelatedNodeType<T extends Metadata = Metadata> = + | RelatedNodeInfo<T> + | RelatedNodeInfo<T>[]; /** * Generic abstract class for retrievable nodes */ -export abstract class BaseNode { +export abstract class BaseNode<T extends Metadata = Metadata> { /** * The unique ID of the Node/Document. The trailing underscore is here * to avoid collisions with the id keyword in Python. @@ -46,13 +50,13 @@ export abstract class BaseNode { embedding?: number[]; // Metadata fields - metadata: Record<string, any> = {}; + metadata: T = {} as T; excludedEmbedMetadataKeys: string[] = []; excludedLlmMetadataKeys: string[] = []; - relationships: Partial<Record<NodeRelationship, RelatedNodeType>> = {}; + relationships: Partial<Record<NodeRelationship, RelatedNodeType<T>>> = {}; hash: string = ""; - constructor(init?: Partial<BaseNode>) { + constructor(init?: Partial<BaseNode<T>>) { Object.assign(this, init); } @@ -62,7 +66,7 @@ export abstract class BaseNode { abstract getMetadataStr(metadataMode: MetadataMode): string; abstract setContent(value: any): void; - get sourceNode(): RelatedNodeInfo | undefined { + get sourceNode(): RelatedNodeInfo<T> | undefined { const relationship = this.relationships[NodeRelationship.SOURCE]; if (Array.isArray(relationship)) { @@ -72,7 +76,7 @@ export abstract class BaseNode { return relationship; } - get prevNode(): RelatedNodeInfo | undefined { + get prevNode(): RelatedNodeInfo<T> | undefined { const relationship = this.relationships[NodeRelationship.PREVIOUS]; if (Array.isArray(relationship)) { @@ -84,7 +88,7 @@ export abstract class BaseNode { return relationship; } - get nextNode(): RelatedNodeInfo | undefined { + get nextNode(): RelatedNodeInfo<T> | undefined { const relationship = this.relationships[NodeRelationship.NEXT]; if (Array.isArray(relationship)) { @@ -94,7 +98,7 @@ export abstract class BaseNode { return relationship; } - get parentNode(): RelatedNodeInfo | undefined { + get parentNode(): RelatedNodeInfo<T> | undefined { const relationship = this.relationships[NodeRelationship.PARENT]; if (Array.isArray(relationship)) { @@ -104,7 +108,7 @@ export abstract class BaseNode { return relationship; } - get childNodes(): RelatedNodeInfo[] | undefined { + get childNodes(): RelatedNodeInfo<T>[] | undefined { const relationship = this.relationships[NodeRelationship.CHILD]; if (!Array.isArray(relationship)) { @@ -126,7 +130,7 @@ export abstract class BaseNode { return this.embedding; } - asRelatedNodeInfo(): RelatedNodeInfo { + asRelatedNodeInfo(): RelatedNodeInfo<T> { return { nodeId: this.id_, metadata: this.metadata, @@ -146,7 +150,7 @@ export abstract class BaseNode { /** * TextNode is the default node type for text. Most common node type in LlamaIndex.TS */ -export class TextNode extends BaseNode { +export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> { text: string = ""; startCharIdx?: number; endCharIdx?: number; @@ -154,7 +158,7 @@ export class TextNode extends BaseNode { // metadataTemplate: NOTE write your own formatter if needed metadataSeparator: string = "\n"; - constructor(init?: Partial<TextNode>) { + constructor(init?: Partial<TextNode<T>>) { super(init); Object.assign(this, init); @@ -233,10 +237,10 @@ export class TextNode extends BaseNode { // } // } -export class IndexNode extends TextNode { +export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> { indexId: string = ""; - constructor(init?: Partial<IndexNode>) { + constructor(init?: Partial<IndexNode<T>>) { super(init); Object.assign(this, init); @@ -253,8 +257,8 @@ export class IndexNode extends TextNode { /** * A document is just a special text node with a docId. */ -export class Document extends TextNode { - constructor(init?: Partial<Document>) { +export class Document<T extends Metadata = Metadata> extends TextNode<T> { + constructor(init?: Partial<Document<T>>) { super(init); Object.assign(this, init); @@ -292,7 +296,7 @@ export function jsonToNode(json: any) { /** * A node with a similarity score */ -export interface NodeWithScore { - node: BaseNode; +export interface NodeWithScore<T extends Metadata = Metadata> { + node: BaseNode<T>; score?: number; }