Skip to content
Snippets Groups Projects
Commit 8d8bee52 authored by Sourabh Desai's avatar Sourabh Desai
Browse files

start implementing list index retrievers

parent ed924641
Branches
Tags
No related merge requests found
...@@ -10,11 +10,9 @@ import { ...@@ -10,11 +10,9 @@ import {
import { BaseDocumentStore } from "./storage/docStore/types"; import { BaseDocumentStore } from "./storage/docStore/types";
import { VectorStore } from "./storage/vectorStore/types"; import { VectorStore } from "./storage/vectorStore/types";
export class IndexDict { export abstract class IndexStruct {
indexId: string; indexId: string;
summary?: string; summary?: string;
nodesDict: Record<string, BaseNode> = {};
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
constructor(indexId = uuidv4(), summary = undefined) { constructor(indexId = uuidv4(), summary = undefined) {
this.indexId = indexId; this.indexId = indexId;
...@@ -27,6 +25,18 @@ export class IndexDict { ...@@ -27,6 +25,18 @@ export class IndexDict {
} }
return this.summary; return this.summary;
} }
}
export class IndexDict extends IndexStruct {
nodesDict: Record<string, BaseNode> = {};
docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
getSummary(): string {
if (this.summary === undefined) {
throw new Error("summary field of the index dict is not set");
}
return this.summary;
}
addNode(node: BaseNode, textId?: string) { addNode(node: BaseNode, textId?: string) {
const vectorId = textId ?? node.id_; const vectorId = textId ?? node.id_;
...@@ -34,6 +44,14 @@ export class IndexDict { ...@@ -34,6 +44,14 @@ export class IndexDict {
} }
} }
export class IndexList extends IndexStruct {
nodes: string[] = [];
addNode(node: BaseNode) {
this.nodes.push(node.id_);
}
}
export interface BaseIndexInit<T> { export interface BaseIndexInit<T> {
serviceContext: ServiceContext; serviceContext: ServiceContext;
storageContext: StorageContext; storageContext: StorageContext;
......
import { BaseNode } from "./Node"; import { BaseNode } from "./Node";
import { BaseIndex, BaseIndexInit } from "./BaseIndex"; import { BaseIndex, BaseIndexInit, IndexList } from "./BaseIndex";
import { IndexList } from "./dataStructs/IndexList";
import { BaseRetriever } from "./Retriever"; import { BaseRetriever } from "./Retriever";
import { ListIndexRetriever } from "./retrievers/ListIndexRetriever"; import { ListIndexRetriever } from "./ListIndexRetriever";
import { ListIndexEmbeddingRetriever } from "./retrievers/ListIndexEmbeddingRetriever";
import { ListIndexLLMRetriever } from "./retrievers/ListIndexLLMRetriever";
import { ServiceContext } from "./ServiceContext"; import { ServiceContext } from "./ServiceContext";
import { RefDocInfo } from "./storage/docStore/types";
import _ from "lodash";
export enum ListRetrieverMode { export enum ListRetrieverMode {
DEFAULT = "default", DEFAULT = "default",
EMBEDDING = "embedding", // EMBEDDING = "embedding",
LLM = "llm", LLM = "llm",
} }
export interface ListIndexInit extends BaseIndexInit<IndexList> { export interface ListIndexInit extends BaseIndexInit<IndexList> {
nodes?: BaseNode[]; nodes?: BaseNode[];
indexStruct?: IndexList; indexStruct: IndexList;
serviceContext?: ServiceContext; serviceContext: ServiceContext;
} }
export class ListIndex extends BaseIndex<IndexList> { export class ListIndex extends BaseIndex<IndexList> {
...@@ -30,10 +29,6 @@ export class ListIndex extends BaseIndex<IndexList> { ...@@ -30,10 +29,6 @@ export class ListIndex extends BaseIndex<IndexList> {
switch (mode) { switch (mode) {
case ListRetrieverMode.DEFAULT: case ListRetrieverMode.DEFAULT:
return new ListIndexRetriever(this); return new ListIndexRetriever(this);
case ListRetrieverMode.EMBEDDING:
throw new Error(
`Support for Embedding retriever mode is not implemented`
);
case ListRetrieverMode.LLM: case ListRetrieverMode.LLM:
throw new Error(`Support for LLM retriever mode is not implemented`); throw new Error(`Support for LLM retriever mode is not implemented`);
default: default:
...@@ -71,11 +66,15 @@ export class ListIndex extends BaseIndex<IndexList> { ...@@ -71,11 +66,15 @@ export class ListIndex extends BaseIndex<IndexList> {
for (const node of nodes) { for (const node of nodes) {
const refNode = node.sourceNode; const refNode = node.sourceNode;
if (!refNode) continue; if (_.isNil(refNode)) {
continue;
}
const refDocInfo = this.docStore.getRefDocInfo(refNode.nodeId); const refDocInfo = await this.docStore.getRefDocInfo(refNode.nodeId);
if (!refDocInfo) continue; if (_.isNil(refDocInfo)) {
continue;
}
refDocInfoMap[refNode.nodeId] = refDocInfo; refDocInfoMap[refNode.nodeId] = refDocInfo;
} }
......
import { BaseRetriever } from "./Retriever";
import { NodeWithScore } from "./Node";
import { ListIndex } from "./ListIndex";
import { ServiceContext } from "./ServiceContext";
import {
ChoiceSelectPrompt,
DEFAULT_CHOICE_SELECT_PROMPT,
} from "./ChoiceSelectPrompt";
import {
defaultFormatNodeBatchFn,
defaultParseChoiceSelectAnswerFn,
} from "./Utils";
/**
* Simple retriever for ListIndex that returns all nodes
*/
export class ListIndexRetriever implements BaseRetriever {
index: ListIndex;
constructor(index: ListIndex) {
this.index = index;
}
async aretrieve(query: string): Promise<NodeWithScore[]> {
const nodeIds = this.index.indexStruct.nodes;
const nodes = await this.index.docStore.getNodes(nodeIds);
return nodes.map((node) => ({
node: node,
}));
}
}
/**
* LLM retriever for ListIndex.
*/
export class ListIndexLLMRetriever implements BaseRetriever {
index: ListIndex;
choiceSelectPrompt: ChoiceSelectPrompt;
choiceBatchSize: number;
formatNodeBatchFn: Function;
parseChoiceSelectAnswerFn: Function;
serviceContext: ServiceContext;
constructor(
index: ListIndex,
choiceSelectPrompt?: ChoiceSelectPrompt,
choiceBatchSize: number = 10,
formatNodeBatchFn?: Function,
parseChoiceSelectAnswerFn?: Function,
serviceContext?: ServiceContext
) {
this.index = index;
this.choiceSelectPrompt =
choiceSelectPrompt || DEFAULT_CHOICE_SELECT_PROMPT;
this.choiceBatchSize = choiceBatchSize;
this.formatNodeBatchFn = formatNodeBatchFn || defaultFormatNodeBatchFn;
this.parseChoiceSelectAnswerFn =
parseChoiceSelectAnswerFn || defaultParseChoiceSelectAnswerFn;
this.serviceContext = serviceContext || index.serviceContext;
}
async aretrieve(query: string): Promise<NodeWithScore[]> {
const nodeIds = this.index.indexStruct.nodes;
const results: NodeWithScore[] = [];
for (let idx = 0; idx < nodeIds.length; idx += this.choiceBatchSize) {
const nodeIdsBatch = nodeIds.slice(idx, idx + this.choiceBatchSize);
const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);
const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
const rawResponse = await this.serviceContext.llmPredictor.apredict(
this.choiceSelectPrompt,
fmtBatchStr,
query
);
const [rawChoices, relevances] = this.parseChoiceSelectAnswerFn(
rawResponse,
nodesBatch.length
);
const choiceIndexes = rawChoices.map(
(choice: string) => parseInt(choice) - 1
);
const choiceNodeIds = choiceIndexes.map(
(idx: number) => nodeIdsBatch[idx]
);
const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds);
const relevancesFilled =
relevances || new Array(choiceNodes.length).fill(1.0);
const nodeWithScores = choiceNodes.map((node, i) => ({
node: node,
score: relevancesFilled[i],
}));
results.push(...nodeWithScores);
}
return results;
}
}
...@@ -227,7 +227,7 @@ export class ImageDocument extends Document { ...@@ -227,7 +227,7 @@ export class ImageDocument extends Document {
export interface NodeWithScore { export interface NodeWithScore {
node: BaseNode; node: BaseNode;
score: number; score?: number;
} }
export interface NodeWithEmbedding { export interface NodeWithEmbedding {
......
import { VectorStoreIndex } from "./BaseIndex"; import { VectorStoreIndex } from "./BaseIndex";
import { BaseEmbedding, getTopKEmbeddings } from "./Embedding";
import { NodeWithScore } from "./Node"; import { NodeWithScore } from "./Node";
import { ServiceContext } from "./ServiceContext"; import { ServiceContext } from "./ServiceContext";
import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment