diff --git a/apps/simple/index.ts b/apps/simple/index.ts index d7e804240f0158e99b27333501ed615e61fa4ca3..733bb7f07fedcf7c8b74309fe384b214f7cec6cc 100644 --- a/apps/simple/index.ts +++ b/apps/simple/index.ts @@ -1,9 +1,9 @@ -import { Document } from "@llamaindex/core/src/Document"; +import { Document } from "@llamaindex/core/src/Node"; import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex"; import essay from "./essay"; (async () => { - const document = new Document(essay); + const document = new Document({ text: essay }); const index = await VectorStoreIndex.fromDocuments([document]); const queryEngine = index.asQueryEngine(); const response = await queryEngine.aquery( diff --git a/apps/simple/lowlevel.ts b/apps/simple/lowlevel.ts new file mode 100644 index 0000000000000000000000000000000000000000..ebfdb076a51751183af42487e0570c6891f8bdab --- /dev/null +++ b/apps/simple/lowlevel.ts @@ -0,0 +1,31 @@ +import { Document, TextNode, NodeWithScore } from "@llamaindex/core/src/Node"; +import { ResponseSynthesizer } from "@llamaindex/core/src/ResponseSynthesizer"; +import { SimpleNodeParser } from "@llamaindex/core/src/NodeParser"; + +(async () => { + const nodeParser = new SimpleNodeParser(); + const nodes = nodeParser.getNodesFromDocuments([ + new Document({ text: "I am 10 years old. John is 20 years old." }), + ]); + + console.log(nodes); + + const responseSynthesizer = new ResponseSynthesizer(); + + const nodesWithScore: NodeWithScore[] = [ + { + node: new TextNode({ text: "I am 10 years old." }), + score: 1, + }, + { + node: new TextNode({ text: "John is 20 years old." }), + score: 0.5, + }, + ]; + + const response = await responseSynthesizer.asynthesize( + "What age am I?", + nodesWithScore + ); + console.log(response.response); +})(); diff --git a/packages/core/src/BaseIndex.ts b/packages/core/src/BaseIndex.ts index 9934e56071b23a91b65a536dabfcbe041a5c38a5..2b05ad588fa85c84c011db8302236987654279de 100644 --- a/packages/core/src/BaseIndex.ts +++ b/packages/core/src/BaseIndex.ts @@ -1,14 +1,13 @@ -import { Document } from "./Document"; -import { Node, NodeWithEmbedding } from "./Node"; +import { Document, TextNode } from "./Node"; import { SimpleNodeParser } from "./NodeParser"; import { BaseQueryEngine, RetrieverQueryEngine } from "./QueryEngine"; import { v4 as uuidv4 } from "uuid"; import { VectorIndexRetriever } from "./Retriever"; import { BaseEmbedding, OpenAIEmbedding } from "./Embedding"; export class BaseIndex { - nodes: Node[] = []; + nodes: TextNode[] = []; - constructor(nodes?: Node[]) { + constructor(nodes?: TextNode[]) { this.nodes = nodes ?? []; } } @@ -16,7 +15,7 @@ export class BaseIndex { export class IndexDict { indexId: string; summary?: string; - nodesDict: Record<string, Node> = {}; + nodesDict: Record<string, TextNode> = {}; docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext constructor(indexId = uuidv4(), summary = undefined) { @@ -31,18 +30,17 @@ export class IndexDict { return this.summary; } - addNode(node: Node, textId?: string) { - const vectorId = textId ?? node.getDocId(); + addNode(node: TextNode, textId?: string) { + const vectorId = textId ?? node.id_; this.nodesDict[vectorId] = node; } } export class VectorStoreIndex extends BaseIndex { indexStruct: IndexDict; - nodesWithEmbeddings: NodeWithEmbedding[] = []; // FIXME replace with storage context embeddingService: BaseEmbedding; // FIXME replace with service context - constructor(nodes: Node[]) { + constructor(nodes: TextNode[]) { super(nodes); this.indexStruct = new IndexDict(); @@ -62,7 +60,7 @@ export class VectorStoreIndex extends BaseIndex { const embedding = await this.embeddingService.aGetTextEmbedding( node.getText() ); - this.nodesWithEmbeddings.push({ node: node, embedding: embedding }); + node.embedding = embedding; } } diff --git a/packages/core/src/Document.ts b/packages/core/src/Document.ts deleted file mode 100644 index a8b9c8559bed7bc8a967ee5d6b96958010a5aa9d..0000000000000000000000000000000000000000 --- a/packages/core/src/Document.ts +++ /dev/null @@ -1,75 +0,0 @@ -import { v4 as uuidv4 } from "uuid"; - -export enum NodeType { - DOCUMENT, - TEXT, - IMAGE, - INDEX, -} - -export abstract class BaseDocument { - text: string; - docId?: string; - embedding?: number[]; - docHash?: string; - - constructor( - text: string, - docId?: string, - embedding?: number[], - docHash?: string - ) { - this.text = text; - this.docId = docId; - this.embedding = embedding; - this.docHash = docHash; - - if (!docId) { - this.docId = uuidv4(); - } - } - - getText() { - if (this.text === undefined) { - throw new Error("Text not set"); - } - return this.text; - } - - getDocId() { - if (this.docId === undefined) { - throw new Error("doc id not set"); - } - return this.docId; - } - - getEmbedding() { - if (this.embedding === undefined) { - throw new Error("Embedding not set"); - } - return this.embedding; - } - - getDocHash() { - if (this.docHash === undefined) { - throw new Error("Doc hash not set"); - } - return this.docHash; - } - - abstract getType(): NodeType; -} - -export class Document extends BaseDocument { - getType() { - return NodeType.DOCUMENT; - } -} - -export class ImageDocument extends Document { - image?: string; - - getType() { - return NodeType.IMAGE; - } -} diff --git a/packages/core/src/LLMPredictor.ts b/packages/core/src/LLMPredictor.ts index ac54bc1aa5a3af8c6568cd6dd3dbc299d464f81e..d63e0bbaff8087814f998f02f261a1946deca0d9 100644 --- a/packages/core/src/LLMPredictor.ts +++ b/packages/core/src/LLMPredictor.ts @@ -5,7 +5,7 @@ export interface BaseLLMPredictor { getLlmMetadata(): Promise<any>; apredict( prompt: string | SimplePrompt, - input?: { [key: string]: string } + input?: Record<string, string> ): Promise<string>; // stream(prompt: string, options: any): Promise<any>; } @@ -31,7 +31,7 @@ export class ChatGPTLLMPredictor implements BaseLLMPredictor { async apredict( prompt: string | SimplePrompt, - input?: { [key: string]: string } + input?: Record<string, string> ): Promise<string> { if (typeof prompt === "string") { const result = await this.languageModel.agenerate([ diff --git a/packages/core/src/LanguageModel.ts b/packages/core/src/LanguageModel.ts index 8862e2fd905ad0f56e0ab45d8df95377b6c62dfe..c626ece8cf2449e14d8da202e6ef8828a1961368 100644 --- a/packages/core/src/LanguageModel.ts +++ b/packages/core/src/LanguageModel.ts @@ -17,7 +17,7 @@ interface BaseMessage { interface Generation { text: string; - generationInfo?: { [key: string]: any }; + generationInfo?: Record<string, any>; } export interface LLMResult { diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index f97b84a5f4d5d62d9fea767520a31995b5984f81..169413b5b377f31960ae65cd90cd46205cd2900d 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -1,72 +1,231 @@ -import { BaseDocument, NodeType } from "./Document"; - -export enum DocumentRelationship { - SOURCE = "source", - PREVIOUS = "previous", - NEXT = "next", - PARENT = "parent", - CHILD = "child", +import { v4 as uuidv4 } from "uuid"; + +export enum NodeRelationship { + SOURCE = "SOURCE", + PREVIOUS = "PREVIOUS", + NEXT = "NEXT", + PARENT = "PARENT", + CHILD = "CHILD", +} + +export enum ObjectType { + TEXT = "TEXT", + IMAGE = "IMAGE", + INDEX = "INDEX", + DOCUMENT = "DOCUMENT", +} + +export enum MetadataMode { + ALL = "ALL", + EMBED = "EMBED", + LLM = "LLM", + NONE = "NONE", } -export class Node extends BaseDocument { - relationships: { [key in DocumentRelationship]: string | string[] | null }; - - constructor( - text: string, // Text is required - docId?: string, - embedding?: number[], - docHash?: string - ) { - if (text === undefined) { - throw new Error("Text is required"); +export interface RelatedNodeInfo { + nodeId: string; + nodeType?: ObjectType; + metadata: Record<string, any>; + hash?: string; +} + +export type RelatedNodeType = RelatedNodeInfo | RelatedNodeInfo[]; + +/** + * Generic abstract class for retrievable nodes + */ +export abstract class BaseNode { + id_: string = uuidv4(); + embedding?: number[]; + + // Metadata fields + metadata: Record<string, any> = {}; + excludedEmbedMetadataKeys: string[] = []; + excludedLlmMetadataKeys: string[] = []; + relationships: Partial<Record<NodeRelationship, RelatedNodeType>> = {}; + hash: string = ""; + + constructor(init?: Partial<BaseNode>) { + Object.assign(this, init); + } + + abstract getType(): ObjectType; + + abstract getContent(metadataMode: MetadataMode): string; + abstract getMetadataStr(metadataMode: MetadataMode): string; + abstract setContent(value: any): void; + + get nodeId(): string { + return this.id_; + } + + get sourceNode(): RelatedNodeInfo | undefined { + const relationship = this.relationships[NodeRelationship.SOURCE]; + + if (Array.isArray(relationship)) { + throw new Error("Source object must be a single RelatedNodeInfo object"); } - super(text, docId, embedding, docHash); + return relationship; + } + + get prevNode(): RelatedNodeInfo | undefined { + const relationship = this.relationships[NodeRelationship.PREVIOUS]; + + if (Array.isArray(relationship)) { + throw new Error( + "Previous object must be a single RelatedNodeInfo object" + ); + } + + return relationship; + } + + get nextNode(): RelatedNodeInfo | undefined { + const relationship = this.relationships[NodeRelationship.NEXT]; + + if (Array.isArray(relationship)) { + throw new Error("Next object must be a single RelatedNodeInfo object"); + } + + return relationship; + } + + get parentNode(): RelatedNodeInfo | undefined { + const relationship = this.relationships[NodeRelationship.PARENT]; + + if (Array.isArray(relationship)) { + throw new Error("Parent object must be a single RelatedNodeInfo object"); + } + + return relationship; + } + + get childNodes(): RelatedNodeInfo[] | undefined { + const relationship = this.relationships[NodeRelationship.CHILD]; - this.relationships = { - source: null, - previous: null, - next: null, - parent: null, - child: [], + if (!Array.isArray(relationship)) { + throw new Error( + "Child object must be a an array of RelatedNodeInfo objects" + ); + } + + return relationship; + } + + getEmbedding(): number[] { + if (this.embedding === undefined) { + throw new Error("Embedding not set"); + } + + return this.embedding; + } + + asRelatedNodeInfo(): RelatedNodeInfo { + return { + nodeId: this.nodeId, + metadata: this.metadata, + hash: this.hash, }; } +} - getNodeInfo(): { [key: string]: any } { - return {}; +export class TextNode extends BaseNode { + text: string = ""; + startCharIdx?: number; + endCharIdx?: number; + // textTemplate: NOTE write your own formatter if needed + // metadataTemplate: NOTE write your own formatter if needed + metadataSeperator: string = "\n"; + + constructor(init?: Partial<TextNode>) { + super(init); + Object.assign(this, init); + } + + generateHash() { + throw new Error("Not implemented"); } - refDocId(): string | null { - return ""; + getType(): ObjectType { + return ObjectType.TEXT; } - prevNodeId(): string { - throw new Error("Node does not have previous node"); + getContent(metadataMode: MetadataMode = MetadataMode.NONE): string { + const metadataStr = this.getMetadataStr(metadataMode).trim(); + return `${metadataStr}\n\n${this.text}`.trim(); } - nextNodeId(): string { - throw new Error("Node does not have next node"); + getMetadataStr(metadataMode: MetadataMode): string { + if (metadataMode === MetadataMode.NONE) { + return ""; + } + + const usableMetadataKeys = new Set(Object.keys(this.metadata).sort()); + if (metadataMode === MetadataMode.LLM) { + for (const key of this.excludedLlmMetadataKeys) { + usableMetadataKeys.delete(key); + } + } else if (metadataMode === MetadataMode.EMBED) { + for (const key of this.excludedEmbedMetadataKeys) { + usableMetadataKeys.delete(key); + } + } + + return [...usableMetadataKeys] + .map((key) => `${key}: ${this.metadata[key]}`) + .join(this.metadataSeperator); } - parentNodeId(): string { - throw new Error("Node does not have parent node"); + setContent(value: string) { + this.text = value; } - childNodeIds(): string[] { - return []; + getNodeInfo() { + return { start: this.startCharIdx, end: this.endCharIdx }; } - getType() { - return NodeType.TEXT; + getText() { + return this.getContent(MetadataMode.NONE); } } -export interface NodeWithEmbedding { - node: Node; - embedding: number[]; +export class ImageNode extends TextNode { + image: string = ""; + + getType(): ObjectType { + return ObjectType.IMAGE; + } +} + +export class IndexNode extends TextNode { + indexId: string = ""; + + getType(): ObjectType { + return ObjectType.INDEX; + } } export interface NodeWithScore { - node: Node; + node: TextNode; score: number; } + +export class Document extends TextNode { + constructor(init?: Partial<Document>) { + super(init); + Object.assign(this, init); + } + + getType() { + return ObjectType.DOCUMENT; + } + + get docId() { + return this.id_; + } +} + +export class ImageDocument extends Document { + image?: string; +} diff --git a/packages/core/src/NodeParser.ts b/packages/core/src/NodeParser.ts index 1ec66f6f643abaef7cc2c6f3d98f9dc442b3a156..52db5dfbae7bfc1471d98099e0d7d646cb3c2756 100644 --- a/packages/core/src/NodeParser.ts +++ b/packages/core/src/NodeParser.ts @@ -1,5 +1,4 @@ -import { Document } from "./Document"; -import { Node } from "./Node"; +import { Document, NodeRelationship, TextNode } from "./Node"; import { SentenceSplitter } from "./TextSplitter"; export function getTextSplitsFromDocument( @@ -16,13 +15,13 @@ export function getNodesFromDocument( document: Document, textSplitter: SentenceSplitter ) { - let nodes: Node[] = []; + let nodes: TextNode[] = []; const textSplits = getTextSplitsFromDocument(document, textSplitter); textSplits.forEach((textSplit, index) => { - const node = new Node(textSplit); - node.relationships.source = document.getDocId(); + const node = new TextNode({ text: textSplit }); + node.relationships[NodeRelationship.SOURCE] = document.asRelatedNodeInfo(); nodes.push(node); }); @@ -30,7 +29,7 @@ export function getNodesFromDocument( } export interface NodeParser { - getNodesFromDocuments(documents: Document[]): Node[]; + getNodesFromDocuments(documents: Document[]): TextNode[]; } export class SimpleNodeParser implements NodeParser { textSplitter: SentenceSplitter; diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index 8d1db56cf872c51d7d434e926f78b11ecfcc31c5..baa2f3f0a7d0fb9d52830b06f534c8b944132e31 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -3,7 +3,7 @@ * NOTE this is a different interface compared to LlamaIndex Python * NOTE 2: we default to empty string to make it easy to calculate prompt sizes */ -export type SimplePrompt = (input: { [key: string]: string }) => string; +export type SimplePrompt = (input: Record<string, string>) => string; /* DEFAULT_TEXT_QA_PROMPT_TMPL = ( diff --git a/packages/core/src/Reader.ts b/packages/core/src/Reader.ts deleted file mode 100644 index 9ceea3c1f4b80ac84ccc147e8c85cf8b73262dd5..0000000000000000000000000000000000000000 --- a/packages/core/src/Reader.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { Document } from "./Document"; - -export interface BaseReader { - loadData(...args: any[]): Promise<Document>; -} - -export class SimpleDirectoryReader implements BaseReader { - async loadData(_options: any) { - return new Document("1", ""); - } -} - -export class PDFReader implements BaseReader { - async loadData(_options: any) { - return new Document("1", ""); - } -} diff --git a/packages/core/src/Response.ts b/packages/core/src/Response.ts index 03e0bb823c156bdc67538edade17f467ed8250c6..83121d77a8e3173bae41205a0c9ec579bac488b4 100644 --- a/packages/core/src/Response.ts +++ b/packages/core/src/Response.ts @@ -1,10 +1,10 @@ -import { Node } from "./Node"; +import { TextNode } from "./Node"; export class Response { response?: string; - sourceNodes: Node[]; + sourceNodes: TextNode[]; - constructor(response?: string, sourceNodes?: Node[]) { + constructor(response?: string, sourceNodes?: TextNode[]) { this.response = response; this.sourceNodes = sourceNodes || []; } diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index e9b03cc87dccc523d02309efbc71b355a291c5f8..7c955ac1a3645cc46d1218181061fefdf22acb88 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,5 +1,5 @@ import { VectorStoreIndex } from "./BaseIndex"; -import { BaseEmbedding, OpenAIEmbedding, getTopKEmbeddings } from "./Embedding"; +import { BaseEmbedding, getTopKEmbeddings } from "./Embedding"; import { NodeWithScore } from "./Node"; import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; @@ -23,9 +23,9 @@ export class VectorIndexRetriever implements BaseRetriever { ); const [similarities, ids] = getTopKEmbeddings( queryEmbedding, - this.index.nodesWithEmbeddings.map((node) => node.embedding), + this.index.nodes.map((node) => node.getEmbedding()), undefined, - this.index.nodesWithEmbeddings.map((node) => node.node.docId) + this.index.nodes.map((node) => node.id_) ); let nodesWithScores: NodeWithScore[] = []; diff --git a/packages/core/src/readers/PDFReader.ts b/packages/core/src/readers/PDFReader.ts index 91b9f137975918d28c369cfbbbbba3897585fbe0..066043471af9b79ba64ebe4c3575c0755431b020 100644 --- a/packages/core/src/readers/PDFReader.ts +++ b/packages/core/src/readers/PDFReader.ts @@ -1,4 +1,4 @@ -import { Document } from "../Document"; +import { Document } from "../Node"; import { BaseReader } from "./base"; import { GenericFileSystem } from "../storage/FileSystem"; import { DEFAULT_FS } from "../storage/constants"; @@ -12,6 +12,6 @@ export default class PDFReader implements BaseReader { ): Promise<Document[]> { let dataBuffer = (await fs.readFile(file)) as any; const data = await pdfParse(dataBuffer); - return [new Document(data.text, file)]; + return [new Document({ text: data.text, id_: file })]; } } diff --git a/packages/core/src/readers/SimpleDirectoryReader.ts b/packages/core/src/readers/SimpleDirectoryReader.ts index d4b850489138865581b6004b53c9c8865203dd7e..ef5b2d8ec77da1a298da53b475e69f9ba8f8b784 100644 --- a/packages/core/src/readers/SimpleDirectoryReader.ts +++ b/packages/core/src/readers/SimpleDirectoryReader.ts @@ -1,5 +1,5 @@ import _ from "lodash"; -import { Document } from "../Document"; +import { Document } from "../Node"; import { BaseReader } from "./base"; import { CompleteFileSystem, walk } from "../storage/FileSystem"; import { DEFAULT_FS } from "../storage/constants"; @@ -11,11 +11,11 @@ export class TextFileReader implements BaseReader { fs: CompleteFileSystem = DEFAULT_FS as CompleteFileSystem ): Promise<Document[]> { const dataBuffer = await fs.readFile(file, "utf-8"); - return [new Document(dataBuffer, file)]; + return [new Document({ text: dataBuffer, id_: file })]; } } -const FILE_EXT_TO_READER: { [key: string]: BaseReader } = { +const FILE_EXT_TO_READER: Record<string, BaseReader> = { txt: new TextFileReader(), pdf: new PDFReader(), }; @@ -24,7 +24,7 @@ export type SimpleDirectoryReaderLoadDataProps = { directoryPath: string; fs?: CompleteFileSystem; defaultReader?: BaseReader | null; - fileExtToReader?: { [key: string]: BaseReader }; + fileExtToReader?: Record<string, BaseReader>; }; export default class SimpleDirectoryReader implements BaseReader { diff --git a/packages/core/src/storage/FileSystem.ts b/packages/core/src/storage/FileSystem.ts index fbf8e2e2f6bbf63a94f6ab74ee824c7d49c7dd48..ce3d632182fcaf5c7854dcc997b6d35c5814d02a 100644 --- a/packages/core/src/storage/FileSystem.ts +++ b/packages/core/src/storage/FileSystem.ts @@ -22,7 +22,7 @@ export interface WalkableFileSystem { * A filesystem implementation that stores files in memory. */ export class InMemoryFileSystem implements GenericFileSystem { - private files: { [filepath: string]: any } = {}; + private files: Record<string, any> = {}; async writeFile(path: string, content: string, options?: any): Promise<void> { this.files[path] = _.cloneDeep(content); diff --git a/packages/core/src/storage/docStore/KVDocumentStore.ts b/packages/core/src/storage/docStore/KVDocumentStore.ts index f7809ba59e58c00e5bccd0332299cd8d23ab3aa4..64b9780d762ce260226364747a5a39fd6a19057d 100644 --- a/packages/core/src/storage/docStore/KVDocumentStore.ts +++ b/packages/core/src/storage/docStore/KVDocumentStore.ts @@ -1,5 +1,4 @@ -import { Node } from "../../Node"; -import { BaseDocument } from "../../Document"; +import { BaseNode, Document, ObjectType, TextNode } from "../../Node"; import { BaseDocumentStore, RefDocInfo } from "./types"; import { BaseKVStore } from "../kvStore/types"; import _, * as lodash from "lodash"; @@ -22,9 +21,9 @@ export class KVDocumentStore extends BaseDocumentStore { this.metadataCollection = `${namespace}/metadata`; } - async docs(): Promise<Record<string, BaseDocument>> { + async docs(): Promise<Record<string, BaseNode>> { let jsonDict = await this.kvstore.getAll(this.nodeCollection); - let docs: Record<string, BaseDocument> = {}; + let docs: Record<string, BaseNode> = {}; for (let key in jsonDict) { docs[key] = jsonToDoc(jsonDict[key] as Record<string, any>); } @@ -32,40 +31,39 @@ export class KVDocumentStore extends BaseDocumentStore { } async addDocuments( - docs: BaseDocument[], + docs: BaseNode[], allowUpdate: boolean = true ): Promise<void> { for (var idx = 0; idx < docs.length; idx++) { const doc = docs[idx]; - if (doc.getDocId() === null) { + if (doc.id_ === null) { throw new Error("doc_id not set"); } - if (!allowUpdate && (await this.documentExists(doc.getDocId()))) { + if (!allowUpdate && (await this.documentExists(doc.id_))) { throw new Error( - `doc_id ${doc.getDocId()} already exists. Set allow_update to True to overwrite.` + `doc_id ${doc.id_} already exists. Set allow_update to True to overwrite.` ); } - let nodeKey = doc.getDocId(); + let nodeKey = doc.id_; let data = docToJson(doc); await this.kvstore.put(nodeKey, data, this.nodeCollection); - let metadata: DocMetaData = { docHash: doc.getDocHash() }; + let metadata: DocMetaData = { docHash: doc.hash }; - if (doc instanceof Node && doc.refDocId() !== null) { - const nodeDoc = doc as Node; - let refDocInfo = (await this.getRefDocInfo(nodeDoc.refDocId()!)) || { + if (doc.getType() === ObjectType.TEXT && doc.sourceNode !== undefined) { + let refDocInfo = (await this.getRefDocInfo(doc.sourceNode.nodeId)) || { docIds: [], extraInfo: {}, }; - refDocInfo.docIds.push(nodeDoc.getDocId()); + refDocInfo.docIds.push(doc.id_); if (_.isEmpty(refDocInfo.extraInfo)) { - refDocInfo.extraInfo = nodeDoc.getNodeInfo() || {}; + refDocInfo.extraInfo = {}; } await this.kvstore.put( - nodeDoc.refDocId()!, + doc.sourceNode.nodeId, refDocInfo, this.refDocCollection ); - metadata.refDocId = nodeDoc.refDocId()!; + metadata.refDocId = doc.sourceNode.nodeId!; } this.kvstore.put(nodeKey, metadata, this.metadataCollection); @@ -75,7 +73,7 @@ export class KVDocumentStore extends BaseDocumentStore { async getDocument( docId: string, raiseError: boolean = true - ): Promise<BaseDocument | undefined> { + ): Promise<BaseNode | undefined> { let json = await this.kvstore.get(docId, this.nodeCollection); if (_.isNil(json)) { if (raiseError) { diff --git a/packages/core/src/storage/docStore/SimpleDocumentStore.ts b/packages/core/src/storage/docStore/SimpleDocumentStore.ts index 1f8fa7168155cc4a052090410f326c162d2c1595..f2554a3ee746a9c83c3573a0c11183c5020ee1df 100644 --- a/packages/core/src/storage/docStore/SimpleDocumentStore.ts +++ b/packages/core/src/storage/docStore/SimpleDocumentStore.ts @@ -11,7 +11,7 @@ import { DEFAULT_FS, } from "../constants"; -type SaveDict = { [key: string]: any }; +type SaveDict = Record<string, any>; export class SimpleDocumentStore extends KVDocumentStore { private kvStore: SimpleKVStore; diff --git a/packages/core/src/storage/docStore/types.ts b/packages/core/src/storage/docStore/types.ts index c58302adbe97c91ede9d61089aa7e332a194d1b6..e744fe0609b7eafc414d6d5430467e8ca84dae12 100644 --- a/packages/core/src/storage/docStore/types.ts +++ b/packages/core/src/storage/docStore/types.ts @@ -1,5 +1,4 @@ -import { Node } from "../../Node"; -import { BaseDocument } from "../../Document"; +import { BaseNode } from "../../Node"; import { GenericFileSystem } from "../FileSystem"; import { DEFAULT_PERSIST_DIR, @@ -10,7 +9,7 @@ const defaultPersistPath = `${DEFAULT_PERSIST_DIR}/${DEFAULT_DOC_STORE_PERSIST_F export interface RefDocInfo { docIds: string[]; - extraInfo: { [key: string]: any }; + extraInfo: Record<string, any>; } export abstract class BaseDocumentStore { @@ -23,14 +22,14 @@ export abstract class BaseDocumentStore { } // Main interface - abstract docs(): Promise<Record<string, BaseDocument>>; + abstract docs(): Promise<Record<string, BaseNode>>; - abstract addDocuments(docs: BaseDocument[], allowUpdate: boolean): void; + abstract addDocuments(docs: BaseNode[], allowUpdate: boolean): void; abstract getDocument( docId: string, raiseError: boolean - ): Promise<BaseDocument | undefined>; + ): Promise<BaseNode | undefined>; abstract deleteDocument(docId: string, raiseError: boolean): void; @@ -42,24 +41,22 @@ export abstract class BaseDocumentStore { abstract getDocumentHash(docId: string): Promise<string | undefined>; // Ref Docs - abstract getAllRefDocInfo(): Promise< - { [key: string]: RefDocInfo } | undefined - >; + abstract getAllRefDocInfo(): Promise<Record<string, RefDocInfo> | undefined>; abstract getRefDocInfo(refDocId: string): Promise<RefDocInfo | undefined>; abstract deleteRefDoc(refDocId: string, raiseError: boolean): Promise<void>; // Nodes - getNodes(nodeIds: string[], raiseError: boolean = true): Promise<Node[]> { + getNodes(nodeIds: string[], raiseError: boolean = true): Promise<BaseNode[]> { return Promise.all( nodeIds.map((nodeId) => this.getNode(nodeId, raiseError)) ); } - async getNode(nodeId: string, raiseError: boolean = true): Promise<Node> { + async getNode(nodeId: string, raiseError: boolean = true): Promise<BaseNode> { let doc = await this.getDocument(nodeId, raiseError); - if (!(doc instanceof Node)) { + if (!(doc instanceof BaseNode)) { throw new Error(`Document ${nodeId} is not a Node.`); } return doc; @@ -67,8 +64,8 @@ export abstract class BaseDocumentStore { async getNodeDict(nodeIdDict: { [index: number]: string; - }): Promise<{ [index: number]: Node }> { - let result: { [index: number]: Node } = {}; + }): Promise<Record<number, BaseNode>> { + let result: Record<number, BaseNode> = {}; for (let index in nodeIdDict) { result[index] = await this.getNode(nodeIdDict[index]); } diff --git a/packages/core/src/storage/docStore/utils.ts b/packages/core/src/storage/docStore/utils.ts index eea6c81a4a19431eac43f32fad4881e91c29f5ae..8c80a3c875d7a3e14fbc62463cddcc9fd26c571b 100644 --- a/packages/core/src/storage/docStore/utils.ts +++ b/packages/core/src/storage/docStore/utils.ts @@ -1,36 +1,35 @@ -import { Node } from "../../Node"; -import { BaseDocument, Document, NodeType } from "../../Document"; +import { BaseNode, Document, TextNode, ObjectType } from "../../Node"; const TYPE_KEY = "__type__"; const DATA_KEY = "__data__"; -export function docToJson(doc: BaseDocument): Record<string, any> { +export function docToJson(doc: BaseNode): Record<string, any> { return { [DATA_KEY]: JSON.stringify(doc), [TYPE_KEY]: doc.getType(), }; } -export function jsonToDoc(docDict: Record<string, any>): BaseDocument { +export function jsonToDoc(docDict: Record<string, any>): BaseNode { let docType = docDict[TYPE_KEY]; let dataDict = docDict[DATA_KEY]; - let doc: BaseDocument; + let doc: BaseNode; - if (docType === NodeType.DOCUMENT) { - doc = new Document( - dataDict.text, - dataDict.docId, - dataDict.embedding, - dataDict.docHash - ); - } else if (docType === NodeType.TEXT) { - const reslationships = dataDict.relationships; - doc = new Node( - reslationships.text, - reslationships.docId, - reslationships.embedding, - reslationships.docHash - ); + if (docType === ObjectType.DOCUMENT) { + doc = new Document({ + text: dataDict.text, + id_: dataDict.id_, + embedding: dataDict.embedding, + hash: dataDict.hash, + }); + } else if (docType === ObjectType.TEXT) { + const relationships = dataDict.relationships; + doc = new TextNode({ + text: relationships.text, + id_: relationships.id_, + embedding: relationships.embedding, + hash: relationships.hash, + }); } else { throw new Error(`Unknown doc type: ${docType}`); } diff --git a/packages/core/src/storage/kvStore/SimpleKVStore.ts b/packages/core/src/storage/kvStore/SimpleKVStore.ts index 196d7631a9d9d0fd7d9264d12cb6e1a9bd8f6b8a..bc4928325ed6824ec3357cbd04d81e30217661eb 100644 --- a/packages/core/src/storage/kvStore/SimpleKVStore.ts +++ b/packages/core/src/storage/kvStore/SimpleKVStore.ts @@ -4,9 +4,7 @@ import { DEFAULT_COLLECTION, DEFAULT_FS } from "../constants"; import * as _ from "lodash"; import { BaseKVStore } from "./types"; -export interface DataType { - [key: string]: { [key: string]: any }; -} +export type DataType = Record<string, Record<string, any>>; export class SimpleKVStore extends BaseKVStore { private data: DataType; diff --git a/packages/core/src/storage/kvStore/types.ts b/packages/core/src/storage/kvStore/types.ts index b6c3785fdcbdc53a94dd11a9f0685ae39f098f47..0d842bcbde82f5148e8a388a12261a8bc64ee5c8 100644 --- a/packages/core/src/storage/kvStore/types.ts +++ b/packages/core/src/storage/kvStore/types.ts @@ -1,16 +1,16 @@ import { GenericFileSystem } from "../FileSystem"; const defaultCollection = "data"; -type StoredValue = { [key: string]: any } | null; +type StoredValue = Record<string, any> | null; export abstract class BaseKVStore { abstract put( key: string, - val: { [key: string]: any }, + val: Record<string, any>, collection?: string ): Promise<void>; abstract get(key: string, collection?: string): Promise<StoredValue>; - abstract getAll(collection?: string): Promise<{ [key: string]: StoredValue }>; + abstract getAll(collection?: string): Promise<Record<string, StoredValue>>; abstract delete(key: string, collection?: string): Promise<boolean>; } diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts index 489b13022a07943d9eb9b39172d33e0cbef6656b..e7c527b8393ec777e6d6f0de5e600c6ea33251b9 100644 --- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts @@ -1,7 +1,6 @@ import _ from "lodash"; import { GenericFileSystem, exists } from "../FileSystem"; import { - NodeWithEmbedding, VectorStore, VectorStoreQuery, VectorStoreQueryMode, @@ -13,6 +12,7 @@ import { getTopKMMREmbeddings, } from "../../Embedding"; import { DEFAULT_PERSIST_DIR, DEFAULT_FS } from "../constants"; +import { TextNode } from "../../Node"; const LEARNER_MODES = new Set<VectorStoreQueryMode>([ VectorStoreQueryMode.SVM, @@ -23,8 +23,8 @@ const LEARNER_MODES = new Set<VectorStoreQueryMode>([ const MMR_MODE = VectorStoreQueryMode.MMR; class SimpleVectorStoreData { - embeddingDict: { [key: string]: number[] } = {}; - textIdToRefDocId: { [key: string]: string } = {}; + embeddingDict: Record<string, number[]> = {}; + textIdToRefDocId: Record<string, string> = {}; } export class SimpleVectorStore implements VectorStore { @@ -53,12 +53,18 @@ export class SimpleVectorStore implements VectorStore { return this.data.embeddingDict[textId]; } - add(embeddingResults: NodeWithEmbedding[]): string[] { + add(embeddingResults: TextNode[]): string[] { for (let result of embeddingResults) { - this.data.embeddingDict[result.id()] = result.embedding; - this.data.textIdToRefDocId[result.id()] = result.refDocId(); + this.data.embeddingDict[result.id_] = result.getEmbedding(); + + if (!result.sourceNode) { + console.error("Missing source node from TextNode."); + continue; + } + + this.data.textIdToRefDocId[result.id_] = result.sourceNode?.nodeId; } - return embeddingResults.map((result) => result.id()); + return embeddingResults.map((result) => result.id_); } delete(refDocId: string): void { diff --git a/packages/core/src/storage/vectorStore/types.ts b/packages/core/src/storage/vectorStore/types.ts index 74dada2135d437b33c1ad54997157b22ad4fda7a..6a3a42287b05559f7108eff09f1a02cbd4a85b92 100644 --- a/packages/core/src/storage/vectorStore/types.ts +++ b/packages/core/src/storage/vectorStore/types.ts @@ -1,8 +1,8 @@ -import { Node } from "../../Node"; +import { TextNode } from "../../Node"; import { GenericFileSystem } from "../FileSystem"; export interface NodeWithEmbedding { - node: Node; + node: TextNode; embedding: number[]; id(): string; @@ -10,7 +10,7 @@ export interface NodeWithEmbedding { } export interface VectorStoreQueryResult { - nodes?: Node[]; + nodes?: TextNode[]; similarities?: number[]; ids?: string[]; } @@ -68,7 +68,7 @@ export interface VectorStore { storesText: boolean; isEmbeddingQuery?: boolean; client(): any; - add(embeddingResults: NodeWithEmbedding[]): string[]; + add(embeddingResults: TextNode[]): string[]; delete(refDocId: string, deleteKwargs?: any): void; query(query: VectorStoreQuery, kwargs?: any): VectorStoreQueryResult; persist(persistPath: string, fs?: GenericFileSystem): void; diff --git a/packages/core/src/tests/Document.test.ts b/packages/core/src/tests/Document.test.ts index de799d517ce5f34315363f0082c8c0c90ae4e0dc..a0edb65aecf12c84e7c8ec6cd298955259415150 100644 --- a/packages/core/src/tests/Document.test.ts +++ b/packages/core/src/tests/Document.test.ts @@ -1,8 +1,8 @@ -import { Document } from "../Document"; +import { Document } from "../Node"; describe("Document", () => { test("initializes", () => { - const doc = new Document("text", "docId"); + const doc = new Document({ text: "text", id_: "docId" }); expect(doc).toBeDefined(); }); }); diff --git a/packages/core/tsconfig.json b/packages/core/tsconfig.json index 7ca71178d53835722a040fe55cb186962981aec6..059d82421d1946cd28447bcb7bfa231abb6e5bad 100644 --- a/packages/core/tsconfig.json +++ b/packages/core/tsconfig.json @@ -8,7 +8,8 @@ "skipLibCheck": true, "noEmit": true, "strict": true, - "lib": ["es2015", "dom"] + "lib": ["es2015", "dom"], + "target": "ES2015" }, "exclude": ["node_modules"] }