diff --git a/.gitignore b/.gitignore index d1595af422c385474a45c132cfb3bdea332b3467..2641a07fa5489a4a28b76c5e83474c459749f8c9 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ yarn-error.log* # vercel .vercel + +storage/ diff --git a/apps/simple/listIndex.ts b/apps/simple/listIndex.ts new file mode 100644 index 0000000000000000000000000000000000000000..5b7a5203b36431a07ce367016d3f144dce13748d --- /dev/null +++ b/apps/simple/listIndex.ts @@ -0,0 +1,17 @@ +import { Document } from "@llamaindex/core/src/Node"; +import { ListIndex } from "@llamaindex/core/src/index/list"; +import essay from "./essay"; + +async function main() { + const document = new Document({ text: essay }); + const index = await ListIndex.fromDocuments([document]); + const queryEngine = index.asQueryEngine(); + const response = await queryEngine.aquery( + "What did the author do growing up?" + ); + console.log(response.toString()); +} + +main().catch((e: Error) => { + console.error(e, e.stack); +}); diff --git a/apps/simple/index.ts b/apps/simple/vectorIndex.ts similarity index 88% rename from apps/simple/index.ts rename to apps/simple/vectorIndex.ts index 733bb7f07fedcf7c8b74309fe384b214f7cec6cc..d05b5874984e3846ee2de6a55ceb1c29875a827a 100644 --- a/apps/simple/index.ts +++ b/apps/simple/vectorIndex.ts @@ -2,7 +2,7 @@ import { Document } from "@llamaindex/core/src/Node"; import { VectorStoreIndex } from "@llamaindex/core/src/BaseIndex"; import essay from "./essay"; -(async () => { +async function main() { const document = new Document({ text: essay }); const index = await VectorStoreIndex.fromDocuments([document]); const queryEngine = index.asQueryEngine(); @@ -10,4 +10,6 @@ import essay from "./essay"; "What did the author do growing up?" ); console.log(response.toString()); -})(); +} + +main().catch(console.error); diff --git a/packages/core/package.json b/packages/core/package.json index b303d036048822c443c8fe92241dd8c3f6a79e97..f0f5c71ed1cb4a98cf735be3502ae5aae9923e56 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -10,6 +10,9 @@ "uuid": "^9.0.0", "wink-nlp": "^1.14.1" }, + "engines": { + "node": ">=18.0.0" + }, "main": "src/index.ts", "types": "src/index.ts", "scripts": { diff --git a/packages/core/src/BaseIndex.ts b/packages/core/src/BaseIndex.ts index dd6e7461b20ea86a6133aad94dd0adf71a014748..0889be7585703289580c1e6b4921e9f1e8eee6e9 100644 --- a/packages/core/src/BaseIndex.ts +++ b/packages/core/src/BaseIndex.ts @@ -9,11 +9,11 @@ import { } from "./storage/StorageContext"; import { BaseDocumentStore } from "./storage/docStore/types"; import { VectorStore } from "./storage/vectorStore/types"; -export class IndexDict { +import { BaseIndexStore } from "./storage/indexStore/types"; + +export abstract class IndexStruct { indexId: string; summary?: string; - nodesDict: Record<string, BaseNode> = {}; - docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext constructor(indexId = uuidv4(), summary = undefined) { this.indexId = indexId; @@ -26,6 +26,18 @@ export class IndexDict { } return this.summary; } +} + +export class IndexDict extends IndexStruct { + nodesDict: Record<string, BaseNode> = {}; + docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext + + getSummary(): string { + if (this.summary === undefined) { + throw new Error("summary field of the index dict is not set"); + } + return this.summary; + } addNode(node: BaseNode, textId?: string) { const vectorId = textId ?? node.id_; @@ -33,18 +45,28 @@ export class IndexDict { } } +export class IndexList extends IndexStruct { + nodes: string[] = []; + + addNode(node: BaseNode) { + this.nodes.push(node.id_); + } +} + export interface BaseIndexInit<T> { serviceContext: ServiceContext; storageContext: StorageContext; docStore: BaseDocumentStore; - vectorStore: VectorStore; + vectorStore?: VectorStore; + indexStore?: BaseIndexStore; indexStruct: T; } export abstract class BaseIndex<T> { serviceContext: ServiceContext; storageContext: StorageContext; docStore: BaseDocumentStore; - vectorStore: VectorStore; + vectorStore?: VectorStore; + indexStore?: BaseIndexStore; indexStruct: T; constructor(init: BaseIndexInit<T>) { @@ -52,6 +74,7 @@ export abstract class BaseIndex<T> { this.storageContext = init.storageContext; this.docStore = init.docStore; this.vectorStore = init.vectorStore; + this.indexStore = init.indexStore; this.indexStruct = init.indexStruct; } @@ -65,9 +88,16 @@ export interface VectorIndexOptions { storageContext?: StorageContext; } +interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> { + vectorStore: VectorStore; +} + export class VectorStoreIndex extends BaseIndex<IndexDict> { - private constructor(init: BaseIndexInit<IndexDict>) { + vectorStore: VectorStore; + + private constructor(init: VectorIndexConstructorProps) { super(init); + this.vectorStore = init.vectorStore; } static async init(options: VectorIndexOptions): Promise<VectorStoreIndex> { diff --git a/packages/core/src/Prompt.ts b/packages/core/src/Prompt.ts index baa2f3f0a7d0fb9d52830b06f534c8b944132e31..74d02c4bdbfac22fefd424dd78f1c889f4de820a 100644 --- a/packages/core/src/Prompt.ts +++ b/packages/core/src/Prompt.ts @@ -80,3 +80,37 @@ ${context} ------------ Given the new context, refine the original answer to better answer the question. If the context isn't useful, return the original answer.`; }; + +export const defaultChoiceSelectPrompt: SimplePrompt = (input) => { + const { context = "", query = "" } = input; + + return `A list of documents is shown below. Each document has a number next to it along +with a summary of the document. A question is also provided. +Respond with the numbers of the documents +you should consult to answer the question, in order of relevance, as well +as the relevance score. The relevance score is a number from 1-10 based on +how relevant you think the document is to the question. +Do not include any documents that are not relevant to the question. +Example format: +Document 1: +<summary of document 1> + +Document 2: +<summary of document 2> + +... + +Document 10:\n<summary of document 10> + +Question: <question> +Answer: +Doc: 9, Relevance: 7 +Doc: 3, Relevance: 4 +Doc: 7, Relevance: 3 + +Let's try this now: + +${context} +Question: ${query} +Answer:`; +}; diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index d1c942254f9bdf18bda643840bc6b184ddb23a6f..dfb72e88f577c0cbc1ae00421ce28eb18c7e345d 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,5 +1,4 @@ import { VectorStoreIndex } from "./BaseIndex"; -import { BaseEmbedding, getTopKEmbeddings } from "./Embedding"; import { NodeWithScore } from "./Node"; import { ServiceContext } from "./ServiceContext"; import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; diff --git a/packages/core/src/index/list/ListIndex.ts b/packages/core/src/index/list/ListIndex.ts new file mode 100644 index 0000000000000000000000000000000000000000..56a17e0fe07c4d8348c3e503b89b4b77bddd0f65 --- /dev/null +++ b/packages/core/src/index/list/ListIndex.ts @@ -0,0 +1,166 @@ +import { BaseNode, Document } from "../../Node"; +import { BaseIndex, BaseIndexInit, IndexList } from "../../BaseIndex"; +import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine"; +import { + StorageContext, + storageContextFromDefaults, +} from "../../storage/StorageContext"; +import { BaseRetriever } from "../../Retriever"; +import { ListIndexRetriever } from "./ListIndexRetriever"; +import { + ServiceContext, + serviceContextFromDefaults, +} from "../../ServiceContext"; +import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; +import _ from "lodash"; + +export enum ListRetrieverMode { + DEFAULT = "default", + // EMBEDDING = "embedding", + LLM = "llm", +} + +export interface ListIndexOptions { + nodes?: BaseNode[]; + indexStruct?: IndexList; + serviceContext?: ServiceContext; + storageContext?: StorageContext; +} + +export class ListIndex extends BaseIndex<IndexList> { + constructor(init: BaseIndexInit<IndexList>) { + super(init); + } + + static async init(options: ListIndexOptions): Promise<ListIndex> { + const storageContext = + options.storageContext ?? (await storageContextFromDefaults({})); + const serviceContext = + options.serviceContext ?? serviceContextFromDefaults({}); + const { docStore, indexStore } = storageContext; + + let indexStruct: IndexList; + if (options.indexStruct) { + if (options.nodes) { + throw new Error( + "Cannot initialize VectorStoreIndex with both nodes and indexStruct" + ); + } + indexStruct = options.indexStruct; + } else { + if (!options.nodes) { + throw new Error( + "Cannot initialize VectorStoreIndex without nodes or indexStruct" + ); + } + indexStruct = ListIndex._buildIndexFromNodes( + options.nodes, + storageContext.docStore + ); + } + + return new ListIndex({ + storageContext, + serviceContext, + docStore, + indexStore, + indexStruct, + }); + } + + static async fromDocuments( + documents: Document[], + storageContext?: StorageContext, + serviceContext?: ServiceContext + ): Promise<ListIndex> { + storageContext = storageContext ?? (await storageContextFromDefaults({})); + serviceContext = serviceContext ?? serviceContextFromDefaults({}); + const docStore = storageContext.docStore; + + docStore.addDocuments(documents, true); + for (const doc of documents) { + docStore.setDocumentHash(doc.id_, doc.hash); + } + + const nodes = serviceContext.nodeParser.getNodesFromDocuments(documents); + const index = await ListIndex.init({ + nodes, + storageContext, + serviceContext, + }); + return index; + } + + asRetriever( + mode: ListRetrieverMode = ListRetrieverMode.DEFAULT + ): BaseRetriever { + switch (mode) { + case ListRetrieverMode.DEFAULT: + return new ListIndexRetriever(this); + case ListRetrieverMode.LLM: + throw new Error(`Support for LLM retriever mode is not implemented`); + default: + throw new Error(`Unknown retriever mode: ${mode}`); + } + } + + asQueryEngine( + mode: ListRetrieverMode = ListRetrieverMode.DEFAULT + ): BaseQueryEngine { + return new RetrieverQueryEngine(this.asRetriever()); + } + + static _buildIndexFromNodes( + nodes: BaseNode[], + docStore: BaseDocumentStore, + indexStruct?: IndexList + ): IndexList { + indexStruct = indexStruct || new IndexList(); + + docStore.addDocuments(nodes, true); + for (const node of nodes) { + indexStruct.addNode(node); + } + + return indexStruct; + } + + protected _insert(nodes: BaseNode[]): void { + for (const node of nodes) { + this.indexStruct.addNode(node); + } + } + + protected _deleteNode(nodeId: string): void { + this.indexStruct.nodes = this.indexStruct.nodes.filter( + (existingNodeId: string) => existingNodeId !== nodeId + ); + } + + async getRefDocInfo(): Promise<Record<string, RefDocInfo>> { + const nodeDocIds = this.indexStruct.nodes; + const nodes = await this.docStore.getNodes(nodeDocIds); + + const refDocInfoMap: Record<string, RefDocInfo> = {}; + + for (const node of nodes) { + const refNode = node.sourceNode; + if (_.isNil(refNode)) { + continue; + } + + const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId); + + if (_.isNil(refDocInfo)) { + continue; + } + + refDocInfoMap[refNode.nodeId] = refDocInfo; + } + + return refDocInfoMap; + } +} + +// Legacy +export type GPTListIndex = ListIndex; diff --git a/packages/core/src/index/list/ListIndexRetriever.ts b/packages/core/src/index/list/ListIndexRetriever.ts new file mode 100644 index 0000000000000000000000000000000000000000..33e7420755868ce4dc30028160ce620f3d3dce61 --- /dev/null +++ b/packages/core/src/index/list/ListIndexRetriever.ts @@ -0,0 +1,96 @@ +import { BaseRetriever } from "../../Retriever"; +import { NodeWithScore } from "../../Node"; +import { ListIndex } from "./ListIndex"; +import { ServiceContext } from "../../ServiceContext"; +import { + NodeFormatterFunction, + ChoiceSelectParserFunction, + defaultFormatNodeBatchFn, + defaultParseChoiceSelectAnswerFn, +} from "./utils"; +import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt"; +import _ from "lodash"; + +/** + * Simple retriever for ListIndex that returns all nodes + */ +export class ListIndexRetriever implements BaseRetriever { + index: ListIndex; + + constructor(index: ListIndex) { + this.index = index; + } + + async aretrieve(query: string): Promise<NodeWithScore[]> { + const nodeIds = this.index.indexStruct.nodes; + const nodes = await this.index.docStore.getNodes(nodeIds); + return nodes.map((node) => ({ + node: node, + score: 1, + })); + } +} + +/** + * LLM retriever for ListIndex. + */ +export class ListIndexLLMRetriever implements BaseRetriever { + index: ListIndex; + choiceSelectPrompt: SimplePrompt; + choiceBatchSize: number; + formatNodeBatchFn: NodeFormatterFunction; + parseChoiceSelectAnswerFn: ChoiceSelectParserFunction; + serviceContext: ServiceContext; + + constructor( + index: ListIndex, + choiceSelectPrompt?: SimplePrompt, + choiceBatchSize: number = 10, + formatNodeBatchFn?: NodeFormatterFunction, + parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction, + serviceContext?: ServiceContext + ) { + this.index = index; + this.choiceSelectPrompt = choiceSelectPrompt || defaultChoiceSelectPrompt; + this.choiceBatchSize = choiceBatchSize; + this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn; + this.parseChoiceSelectAnswerFn = + parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn; + this.serviceContext = serviceContext || index.serviceContext; + } + + async aretrieve(query: string): Promise<NodeWithScore[]> { + const nodeIds = this.index.indexStruct.nodes; + const results: NodeWithScore[] = []; + + for (let idx = 0; idx < nodeIds.length; idx += this.choiceBatchSize) { + const nodeIdsBatch = nodeIds.slice(idx, idx + this.choiceBatchSize); + const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch); + + const fmtBatchStr = this.formatNodeBatchFn(nodesBatch); + const input = { context: fmtBatchStr, query: query }; + const rawResponse = await this.serviceContext.llmPredictor.apredict( + this.choiceSelectPrompt, + input + ); + + // parseResult is a map from doc number to relevance score + const parseResult = this.parseChoiceSelectAnswerFn( + rawResponse, + nodesBatch.length + ); + const choiceNodeIds = nodeIdsBatch.filter((nodeId, idx) => { + return `${idx}` in parseResult; + }); + + const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds); + const nodeWithScores = choiceNodes.map((node, i) => ({ + node: node, + score: _.get(parseResult, `${i + 1}`, 1), + })); + + results.push(...nodeWithScores); + } + return results; + } +} diff --git a/packages/core/src/index/list/index.ts b/packages/core/src/index/list/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..f8d0b8d5eae44cf561fd8483fbabf6dc716260d0 --- /dev/null +++ b/packages/core/src/index/list/index.ts @@ -0,0 +1,5 @@ +export { ListIndex, ListRetrieverMode } from "./ListIndex"; +export { + ListIndexRetriever, + ListIndexLLMRetriever, +} from "./ListIndexRetriever"; diff --git a/packages/core/src/index/list/utils.ts b/packages/core/src/index/list/utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..b7a1d3f8fddc0af2c587ef1db2c8ff0db01dbfce --- /dev/null +++ b/packages/core/src/index/list/utils.ts @@ -0,0 +1,73 @@ +import { BaseNode, MetadataMode } from "../../Node"; +import _ from "lodash"; + +export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string; +export const defaultFormatNodeBatchFn: NodeFormatterFunction = ( + summaryNodes: BaseNode[] +): string => { + return summaryNodes + .map((node, idx) => { + return ` +Document ${idx + 1}: +${node.getContent(MetadataMode.LLM)} + `.trim(); + }) + .join("\n\n"); +}; + +// map from document number to its relevance score +export type ChoiceSelectParseResult = { [docNumber: number]: number }; +export type ChoiceSelectParserFunction = ( + answer: string, + numChoices: number, + raiseErr?: boolean +) => ChoiceSelectParseResult; + +export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = ( + answer: string, + numChoices: number, + raiseErr: boolean = false +): ChoiceSelectParseResult => { + // split the line into the answer number and relevance score portions + const lineTokens: string[][] = answer + .split("\n") + .map((line: string) => { + let lineTokens = line.split(","); + if (lineTokens.length !== 2) { + if (raiseErr) { + throw new Error( + `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>` + ); + } else { + return null; + } + } + return lineTokens; + }) + .filter((lineTokens) => !_.isNil(lineTokens)) as string[][]; + + // parse the answer number and relevance score + return lineTokens.reduce( + (parseResult: ChoiceSelectParseResult, lineToken: string[]) => { + try { + let docNum = parseInt(lineToken[0].split(":")[1].trim()); + let answerRelevance = parseFloat(lineToken[1].split(":")[1].trim()); + if (docNum < 1 || docNum > numChoices) { + if (raiseErr) { + throw new Error( + `Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}` + ); + } else { + parseResult[docNum] = answerRelevance; + } + } + } catch (e) { + if (raiseErr) { + throw e; + } + } + return parseResult; + }, + {} + ); +}; diff --git a/packages/core/src/storage/docStore/KVDocumentStore.ts b/packages/core/src/storage/docStore/KVDocumentStore.ts index 64b9780d762ce260226364747a5a39fd6a19057d..027672e6d276a3eac279a0aba37e684d905aab80 100644 --- a/packages/core/src/storage/docStore/KVDocumentStore.ts +++ b/packages/core/src/storage/docStore/KVDocumentStore.ts @@ -77,7 +77,7 @@ export class KVDocumentStore extends BaseDocumentStore { let json = await this.kvstore.get(docId, this.nodeCollection); if (_.isNil(json)) { if (raiseError) { - throw new Error(`doc_id ${docId} not found.`); + throw new Error(`docId ${docId} not found.`); } else { return; } diff --git a/packages/core/src/storage/docStore/utils.ts b/packages/core/src/storage/docStore/utils.ts index 8c80a3c875d7a3e14fbc62463cddcc9fd26c571b..a7329df67e14d10343234e49f6538c324b452641 100644 --- a/packages/core/src/storage/docStore/utils.ts +++ b/packages/core/src/storage/docStore/utils.ts @@ -23,12 +23,11 @@ export function jsonToDoc(docDict: Record<string, any>): BaseNode { hash: dataDict.hash, }); } else if (docType === ObjectType.TEXT) { - const relationships = dataDict.relationships; + console.log({ dataDict }); doc = new TextNode({ - text: relationships.text, - id_: relationships.id_, - embedding: relationships.embedding, - hash: relationships.hash, + text: dataDict.text, + id_: dataDict.id_, + hash: dataDict.hash, }); } else { throw new Error(`Unknown doc type: ${docType}`);