From 8d8bee526303ad58edb1965d9775d6c448f526fd Mon Sep 17 00:00:00 2001 From: Sourabh Desai <sourabhdesai@gmail.com> Date: Mon, 3 Jul 2023 06:42:00 +0000 Subject: [PATCH] start implementing list index retrievers --- packages/core/src/BaseIndex.ts | 24 +++++- packages/core/src/ListIndex.ts | 29 ++++--- packages/core/src/ListIndexRetriever.ts | 100 ++++++++++++++++++++++++ packages/core/src/Node.ts | 2 +- packages/core/src/Retriever.ts | 1 - 5 files changed, 136 insertions(+), 20 deletions(-) create mode 100644 packages/core/src/ListIndexRetriever.ts diff --git a/packages/core/src/BaseIndex.ts b/packages/core/src/BaseIndex.ts index 2eb26544e..65452fbd9 100644 --- a/packages/core/src/BaseIndex.ts +++ b/packages/core/src/BaseIndex.ts @@ -10,11 +10,9 @@ import { import { BaseDocumentStore } from "./storage/docStore/types"; import { VectorStore } from "./storage/vectorStore/types"; -export class IndexDict { +export abstract class IndexStruct { indexId: string; summary?: string; - nodesDict: Record<string, BaseNode> = {}; - docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext constructor(indexId = uuidv4(), summary = undefined) { this.indexId = indexId; @@ -27,6 +25,18 @@ export class IndexDict { } return this.summary; } +} + +export class IndexDict extends IndexStruct { + nodesDict: Record<string, BaseNode> = {}; + docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext + + getSummary(): string { + if (this.summary === undefined) { + throw new Error("summary field of the index dict is not set"); + } + return this.summary; + } addNode(node: BaseNode, textId?: string) { const vectorId = textId ?? node.id_; @@ -34,6 +44,14 @@ export class IndexDict { } } +export class IndexList extends IndexStruct { + nodes: string[] = []; + + addNode(node: BaseNode) { + this.nodes.push(node.id_); + } +} + export interface BaseIndexInit<T> { serviceContext: ServiceContext; storageContext: StorageContext; diff --git a/packages/core/src/ListIndex.ts b/packages/core/src/ListIndex.ts index e8e1f0810..ebfd73834 100644 --- a/packages/core/src/ListIndex.ts +++ b/packages/core/src/ListIndex.ts @@ -1,22 +1,21 @@ import { BaseNode } from "./Node"; -import { BaseIndex, BaseIndexInit } from "./BaseIndex"; -import { IndexList } from "./dataStructs/IndexList"; +import { BaseIndex, BaseIndexInit, IndexList } from "./BaseIndex"; import { BaseRetriever } from "./Retriever"; -import { ListIndexRetriever } from "./retrievers/ListIndexRetriever"; -import { ListIndexEmbeddingRetriever } from "./retrievers/ListIndexEmbeddingRetriever"; -import { ListIndexLLMRetriever } from "./retrievers/ListIndexLLMRetriever"; +import { ListIndexRetriever } from "./ListIndexRetriever"; import { ServiceContext } from "./ServiceContext"; +import { RefDocInfo } from "./storage/docStore/types"; +import _ from "lodash"; export enum ListRetrieverMode { DEFAULT = "default", - EMBEDDING = "embedding", + // EMBEDDING = "embedding", LLM = "llm", } export interface ListIndexInit extends BaseIndexInit<IndexList> { nodes?: BaseNode[]; - indexStruct?: IndexList; - serviceContext?: ServiceContext; + indexStruct: IndexList; + serviceContext: ServiceContext; } export class ListIndex extends BaseIndex<IndexList> { @@ -30,10 +29,6 @@ export class ListIndex extends BaseIndex<IndexList> { switch (mode) { case ListRetrieverMode.DEFAULT: return new ListIndexRetriever(this); - case ListRetrieverMode.EMBEDDING: - throw new Error( - `Support for Embedding retriever mode is not implemented` - ); case ListRetrieverMode.LLM: throw new Error(`Support for LLM retriever mode is not implemented`); default: @@ -71,11 +66,15 @@ export class ListIndex extends BaseIndex<IndexList> { for (const node of nodes) { const refNode = node.sourceNode; - if (!refNode) continue; + if (_.isNil(refNode)) { + continue; + } - const refDocInfo = this.docStore.getRefDocInfo(refNode.nodeId); + const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId); - if (!refDocInfo) continue; + if (_.isNil(refDocInfo)) { + continue; + } refDocInfoMap[refNode.nodeId] = refDocInfo; } diff --git a/packages/core/src/ListIndexRetriever.ts b/packages/core/src/ListIndexRetriever.ts new file mode 100644 index 000000000..744c11d48 --- /dev/null +++ b/packages/core/src/ListIndexRetriever.ts @@ -0,0 +1,100 @@ +import { BaseRetriever } from "./Retriever"; +import { NodeWithScore } from "./Node"; +import { ListIndex } from "./ListIndex"; +import { ServiceContext } from "./ServiceContext"; +import { + ChoiceSelectPrompt, + DEFAULT_CHOICE_SELECT_PROMPT, +} from "./ChoiceSelectPrompt"; +import { + defaultFormatNodeBatchFn, + defaultParseChoiceSelectAnswerFn, +} from "./Utils"; + +/** + * Simple retriever for ListIndex that returns all nodes + */ +export class ListIndexRetriever implements BaseRetriever { + index: ListIndex; + + constructor(index: ListIndex) { + this.index = index; + } + + async aretrieve(query: string): Promise<NodeWithScore[]> { + const nodeIds = this.index.indexStruct.nodes; + const nodes = await this.index.docStore.getNodes(nodeIds); + return nodes.map((node) => ({ + node: node, + })); + } +} + +/** + * LLM retriever for ListIndex. + */ +export class ListIndexLLMRetriever implements BaseRetriever { + index: ListIndex; + choiceSelectPrompt: ChoiceSelectPrompt; + choiceBatchSize: number; + formatNodeBatchFn: Function; + parseChoiceSelectAnswerFn: Function; + serviceContext: ServiceContext; + + constructor( + index: ListIndex, + choiceSelectPrompt?: ChoiceSelectPrompt, + choiceBatchSize: number = 10, + formatNodeBatchFn?: Function, + parseChoiceSelectAnswerFn?: Function, + serviceContext?: ServiceContext + ) { + this.index = index; + this.choiceSelectPrompt = + choiceSelectPrompt || DEFAULT_CHOICE_SELECT_PROMPT; + this.choiceBatchSize = choiceBatchSize; + this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn; + this.parseChoiceSelectAnswerFn = + parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn; + this.serviceContext = serviceContext || index.serviceContext; + } + + async aretrieve(query: string): Promise<NodeWithScore[]> { + const nodeIds = this.index.indexStruct.nodes; + const results: NodeWithScore[] = []; + + for (let idx = 0; idx < nodeIds.length; idx += this.choiceBatchSize) { + const nodeIdsBatch = nodeIds.slice(idx, idx + this.choiceBatchSize); + const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch); + + const fmtBatchStr = this.formatNodeBatchFn(nodesBatch); + const rawResponse = await this.serviceContext.llmPredictor.apredict( + this.choiceSelectPrompt, + fmtBatchStr, + query + ); + + const [rawChoices, relevances] = this.parseChoiceSelectAnswerFn( + rawResponse, + nodesBatch.length + ); + const choiceIndexes = rawChoices.map( + (choice: string) => parseInt(choice) - 1 + ); + const choiceNodeIds = choiceIndexes.map( + (idx: number) => nodeIdsBatch[idx] + ); + + const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds); + const relevancesFilled = + relevances || new Array(choiceNodes.length).fill(1.0); + const nodeWithScores = choiceNodes.map((node, i) => ({ + node: node, + score: relevancesFilled[i], + })); + + results.push(...nodeWithScores); + } + return results; + } +} diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index 95d3081c8..970ddd2ce 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -227,7 +227,7 @@ export class ImageDocument extends Document { export interface NodeWithScore { node: BaseNode; - score: number; + score?: number; } export interface NodeWithEmbedding { diff --git a/packages/core/src/Retriever.ts b/packages/core/src/Retriever.ts index d1c942254..dfb72e88f 100644 --- a/packages/core/src/Retriever.ts +++ b/packages/core/src/Retriever.ts @@ -1,5 +1,4 @@ import { VectorStoreIndex } from "./BaseIndex"; -import { BaseEmbedding, getTopKEmbeddings } from "./Embedding"; import { NodeWithScore } from "./Node"; import { ServiceContext } from "./ServiceContext"; import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; -- GitLab