diff --git a/packages/core/src/BaseIndex.ts b/packages/core/src/BaseIndex.ts index 2b05ad588fa85c84c011db8302236987654279de..0a0d00dfae1982a6273ce072a80366a76132b2ce 100644 --- a/packages/core/src/BaseIndex.ts +++ b/packages/core/src/BaseIndex.ts @@ -1,21 +1,18 @@ -import { Document, TextNode } from "./Node"; -import { SimpleNodeParser } from "./NodeParser"; +import { Document, BaseNode, MetadataMode, NodeWithEmbedding } from "./Node"; import { BaseQueryEngine, RetrieverQueryEngine } from "./QueryEngine"; import { v4 as uuidv4 } from "uuid"; -import { VectorIndexRetriever } from "./Retriever"; -import { BaseEmbedding, OpenAIEmbedding } from "./Embedding"; -export class BaseIndex { - nodes: TextNode[] = []; - - constructor(nodes?: TextNode[]) { - this.nodes = nodes ?? []; - } -} - +import { BaseRetriever, VectorIndexRetriever } from "./Retriever"; +import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; +import { + StorageContext, + storageContextFromDefaults, +} from "./storage/StorageContext"; +import { BaseDocumentStore } from "./storage/docStore/types"; +import { VectorStore } from "./storage/vectorStore/types"; export class IndexDict { indexId: string; summary?: string; - nodesDict: Record<string, TextNode> = {}; + nodesDict: Record<string, BaseNode> = {}; docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext constructor(indexId = uuidv4(), summary = undefined) { @@ -30,56 +27,147 @@ export class IndexDict { return this.summary; } - addNode(node: TextNode, textId?: string) { + addNode(node: BaseNode, textId?: string) { const vectorId = textId ?? node.id_; this.nodesDict[vectorId] = node; } } -export class VectorStoreIndex extends BaseIndex { - indexStruct: IndexDict; - embeddingService: BaseEmbedding; // FIXME replace with service context +export interface BaseIndexInit<T> { + serviceContext: ServiceContext; + storageContext: StorageContext; + docStore: BaseDocumentStore; + vectorStore: VectorStore; + indexStruct: T; +} +export abstract class BaseIndex<T> { + serviceContext: ServiceContext; + storageContext: StorageContext; + docStore: BaseDocumentStore; + vectorStore: VectorStore; + indexStruct: T; + + constructor(init: BaseIndexInit<T>) { + this.serviceContext = init.serviceContext; + this.storageContext = init.storageContext; + this.docStore = init.docStore; + this.vectorStore = init.vectorStore; + this.indexStruct = init.indexStruct; + } + + abstract asRetriever(): BaseRetriever; +} + +export interface VectorIndexOptions { + nodes?: BaseNode[]; + indexStruct?: IndexDict; + serviceContext?: ServiceContext; + storageContext?: StorageContext; +} + +export class VectorStoreIndex extends BaseIndex<IndexDict> { + private constructor(init: BaseIndexInit<IndexDict>) { + super(init); + } - constructor(nodes: TextNode[]) { - super(nodes); - this.indexStruct = new IndexDict(); + static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> { + const storageContext = + options.storageContext ?? (await storageContextFromDefaults({})); + const serviceContext = + options.serviceContext ?? serviceContextFromDefaults({}); + const docStore = storageContext.docStore; + const vectorStore = storageContext.vectorStore; - if (nodes !== undefined) { - this.buildIndexFromNodes(); + let indexStruct: IndexDict; + if (options.indexStruct) { + if (options.nodes) { + throw new Error( + "Cannot initialize VectorStoreIndex with both nodes and indexStruct" + ); + } + indexStruct = options.indexStruct; + } else { + if (!options.nodes) { + throw new Error( + "Cannot initialize VectorStoreIndex without nodes or indexStruct" + ); + } + indexStruct = await VectorStoreIndex.buildIndexFromNodes( + options.nodes, + serviceContext, + vectorStore + ); } - this.embeddingService = new OpenAIEmbedding(); + return new VectorStoreIndex({ + storageContext, + serviceContext, + docStore, + vectorStore, + indexStruct, + }); } - async getNodeEmbeddingResults(logProgress = false) { - for (let i = 0; i < this.nodes.length; ++i) { - const node = this.nodes[i]; + static async agetNodeEmbeddingResults( + nodes: BaseNode[], + serviceContext: ServiceContext, + logProgress = false + ) { + const nodesWithEmbeddings: NodeWithEmbedding[] = []; + + for (let i = 0; i < nodes.length; ++i) { + const node = nodes[i]; if (logProgress) { - console.log(`getting embedding for node ${i}/${this.nodes.length}`); + console.log(`getting embedding for node ${i}/${nodes.length}`); } - const embedding = await this.embeddingService.aGetTextEmbedding( - node.getText() + const embedding = await serviceContext.embedModel.aGetTextEmbedding( + node.getContent(MetadataMode.EMBED) ); - node.embedding = embedding; + nodesWithEmbeddings.push({ node, embedding }); } + + return nodesWithEmbeddings; } - buildIndexFromNodes() { - for (const node of this.nodes) { - this.indexStruct.addNode(node); - } + static async buildIndexFromNodes( + nodes: BaseNode[], + serviceContext: ServiceContext, + vectorStore: VectorStore + ): Promise<IndexDict> { + const embeddingResults = await this.agetNodeEmbeddingResults( + nodes, + serviceContext + ); + + vectorStore.add(embeddingResults); + + throw new Error("not implemented"); } - static async fromDocuments(documents: Document[]): Promise<VectorStoreIndex> { - const nodeParser = new SimpleNodeParser(); // FIXME use service context - const nodes = nodeParser.getNodesFromDocuments(documents); - const index = new VectorStoreIndex(nodes); - await index.getNodeEmbeddingResults(); + static async fromDocuments( + documents: Document[], + storageContext?: StorageContext, + serviceContext?: ServiceContext + ): Promise<VectorStoreIndex> { + storageContext = storageContext ?? (await storageContextFromDefaults({})); + serviceContext = serviceContext ?? serviceContextFromDefaults({}); + const docStore = storageContext.docStore; + + for (const doc of documents) { + docStore.setDocumentHash(doc.id_, doc.hash); + } + + const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents); + const index = await VectorStoreIndex.init({ + nodes, + storageContext, + serviceContext, + }); return index; } asRetriever(): VectorIndexRetriever { - return new VectorIndexRetriever(this, this.embeddingService); + return new VectorIndexRetriever(this); } asQueryEngine(): BaseQueryEngine { diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index 169413b5b377f31960ae65cd90cd46205cd2900d..95d3081c8a7158d85e6207d37d075181fdc45951 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -136,7 +136,7 @@ export class TextNode extends BaseNode { endCharIdx?: number; // textTemplate: NOTE write your own formatter if needed // metadataTemplate: NOTE write your own formatter if needed - metadataSeperator: string = "\n"; + metadataSeparator: string = "\n"; constructor(init?: Partial<TextNode>) { super(init); @@ -174,7 +174,7 @@ export class TextNode extends BaseNode { return [...usableMetadataKeys] .map((key) => `${key}: ${this.metadata[key]}`) - .join(this.metadataSeperator); + .join(this.metadataSeparator); } setContent(value: string) { @@ -206,11 +206,6 @@ export class IndexNode extends TextNode { } } -export interface NodeWithScore { - node: TextNode; - score: number; -} - export class Document extends TextNode { constructor(init?: Partial<Document>) { super(init); @@ -229,3 +224,13 @@ export class Document extends TextNode { export class ImageDocument extends Document { image?: string; } + +export interface NodeWithScore { + node: BaseNode; + score: number; +} + +export interface NodeWithEmbedding { + node: BaseNode; + embedding: number[]; +} diff --git a/packages/core/src/Response.ts b/packages/core/src/Response.ts index 83121d77a8e3173bae41205a0c9ec579bac488b4..9ef1b81e4e2f4774d57954a9db4bd4d48c736323 100644 --- a/packages/core/src/Response.ts +++ b/packages/core/src/Response.ts @@ -1,10 +1,10 @@ -import { TextNode } from "./Node"; +import { BaseNode } from "./Node"; export class Response { response?: string; - sourceNodes: TextNode[]; + sourceNodes: BaseNode[]; - constructor(response?: string, sourceNodes?: TextNode[]) { + constructor(response?: string, sourceNodes?: BaseNode[]) { this.response = response; this.sourceNodes = sourceNodes || []; } diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/ResponseSynthesizer.ts index 0250ccd7eaa405e801dbe0f4ef4f8b7e5e1a208d..8e2bc4b63c75445a8942903656d9b4b0ceb781a2 100644 --- a/packages/core/src/ResponseSynthesizer.ts +++ b/packages/core/src/ResponseSynthesizer.ts @@ -1,5 +1,5 @@ import { ChatGPTLLMPredictor } from "./LLMPredictor"; -import { NodeWithScore } from "./Node"; +import { MetadataMode, NodeWithScore } from "./Node"; import { SimplePrompt, defaultRefinePrompt, @@ -190,7 +190,9 @@ export class ResponseSynthesizer { } async asynthesize(query: string, nodes: NodeWithScore[]) { - let textChunks: string[] = nodes.map((node) => node.node.text); + let textChunks: string[] = nodes.map((node) => + node.node.getContent(MetadataMode.NONE) + ); const response = await this.responseBuilder.agetResponse(query, textChunks); return new Response( response, diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index 7c955ac1a3645cc46d1218181061fefdf22acb88..d1c942254f9bdf18bda643840bc6b184ddb23a6f 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,7 +1,12 @@ import { VectorStoreIndex } from "./BaseIndex"; import { BaseEmbedding, getTopKEmbeddings } from "./Embedding"; import { NodeWithScore } from "./Node"; +import { ServiceContext } from "./ServiceContext"; import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; +import { + VectorStoreQuery, + VectorStoreQueryMode, +} from "./storage/vectorStore/types"; export interface BaseRetriever { aretrieve(query: string): Promise<any>; @@ -10,31 +15,30 @@ export interface BaseRetriever { export class VectorIndexRetriever implements BaseRetriever { index: VectorStoreIndex; similarityTopK = DEFAULT_SIMILARITY_TOP_K; - embeddingService: BaseEmbedding; + private serviceContext: ServiceContext; - constructor(index: VectorStoreIndex, embeddingService: BaseEmbedding) { + constructor(index: VectorStoreIndex) { this.index = index; - this.embeddingService = embeddingService; + this.serviceContext = this.index.serviceContext; } async aretrieve(query: string): Promise<NodeWithScore[]> { - const queryEmbedding = await this.embeddingService.aGetQueryEmbedding( - query - ); - const [similarities, ids] = getTopKEmbeddings( - queryEmbedding, - this.index.nodes.map((node) => node.getEmbedding()), - undefined, - this.index.nodes.map((node) => node.id_) - ); + const queryEmbedding = + await this.serviceContext.embedModel.aGetQueryEmbedding(query); - let nodesWithScores: NodeWithScore[] = []; + const q: VectorStoreQuery = { + queryEmbedding: queryEmbedding, + mode: VectorStoreQueryMode.DEFAULT, + similarityTopK: this.similarityTopK, + }; + const result = this.index.vectorStore.query(q); - for (let i = 0; i < ids.length; i++) { - const node = this.index.indexStruct.nodesDict[ids[i]]; + let nodesWithScores: NodeWithScore[] = []; + for (let i = 0; i < result.ids.length; i++) { + const node = this.index.indexStruct.nodesDict[result.ids[i]]; nodesWithScores.push({ node: node, - score: similarities[i], + score: result.similarities[i], }); } diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts index 6a12a08330f02e494f8d930ea4b0b9404b636089..2a635740ebb2d9d7fe3e61df1fd453915e98bd54 100644 --- a/packages/core/src/ServiceContext.ts +++ b/packages/core/src/ServiceContext.ts @@ -21,7 +21,7 @@ export interface ServiceContextOptions { nodeParser?: NodeParser; // NodeParser arguments chunkSize?: number; - chunkOverlap: number; + chunkOverlap?: number; } export function serviceContextFromDefaults(options: ServiceContextOptions) { diff --git a/packages/core/src/storage/StorageContext.ts b/packages/core/src/storage/StorageContext.ts index 43b6086572362c0d0ee3b5c348bb23904aa54871..d7c5e03bcd3254586aab19c2427b9f3efed528cc 100644 --- a/packages/core/src/storage/StorageContext.ts +++ b/packages/core/src/storage/StorageContext.ts @@ -12,9 +12,9 @@ import { } from "./constants"; export interface StorageContext { - docStore?: BaseDocumentStore; - indexStore?: BaseIndexStore; - vectorStore?: VectorStore; + docStore: BaseDocumentStore; + indexStore: BaseIndexStore; + vectorStore: VectorStore; } type BuilderParams = { diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts index e7c527b8393ec777e6d6f0de5e600c6ea33251b9..d0d9cab5181aa0e5ae2ef162d42bdc9b2c5d134b 100644 --- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts @@ -12,7 +12,7 @@ import { getTopKMMREmbeddings, } from "../../Embedding"; import { DEFAULT_PERSIST_DIR, DEFAULT_FS } from "../constants"; -import { TextNode } from "../../Node"; +import { NodeWithEmbedding } from "../../Node"; const LEARNER_MODES = new Set<VectorStoreQueryMode>([ VectorStoreQueryMode.SVM, @@ -53,18 +53,19 @@ export class SimpleVectorStore implements VectorStore { return this.data.embeddingDict[textId]; } - add(embeddingResults: TextNode[]): string[] { + add(embeddingResults: NodeWithEmbedding[]): string[] { for (let result of embeddingResults) { - this.data.embeddingDict[result.id_] = result.getEmbedding(); + this.data.embeddingDict[result.node.id_] = result.embedding; - if (!result.sourceNode) { + if (!result.node.sourceNode) { console.error("Missing source node from TextNode."); continue; } - this.data.textIdToRefDocId[result.id_] = result.sourceNode?.nodeId; + this.data.textIdToRefDocId[result.node.id_] = + result.node.sourceNode?.nodeId; } - return embeddingResults.map((result) => result.id_); + return embeddingResults.map((result) => result.node.id_); } delete(refDocId: string): void { diff --git a/packages/core/src/storage/vectorStore/types.ts b/packages/core/src/storage/vectorStore/types.ts index 6a3a42287b05559f7108eff09f1a02cbd4a85b92..884c5475901648b141cd9ee5cbd7509e6dc9f15f 100644 --- a/packages/core/src/storage/vectorStore/types.ts +++ b/packages/core/src/storage/vectorStore/types.ts @@ -1,18 +1,11 @@ -import { TextNode } from "../../Node"; +import { BaseNode } from "../../Node"; import { GenericFileSystem } from "../FileSystem"; - -export interface NodeWithEmbedding { - node: TextNode; - embedding: number[]; - - id(): string; - refDocId(): string; -} +import { NodeWithEmbedding } from "../../Node"; export interface VectorStoreQueryResult { - nodes?: TextNode[]; - similarities?: number[]; - ids?: string[]; + nodes?: BaseNode[]; + similarities: number[]; + ids: string[]; } export enum VectorStoreQueryMode { @@ -68,7 +61,7 @@ export interface VectorStore { storesText: boolean; isEmbeddingQuery?: boolean; client(): any; - add(embeddingResults: TextNode[]): string[]; + add(embeddingResults: NodeWithEmbedding[]): string[]; delete(refDocId: string, deleteKwargs?: any): void; query(query: VectorStoreQuery, kwargs?: any): VectorStoreQueryResult; persist(persistPath: string, fs?: GenericFileSystem): void;