Skip to content
Snippets Groups Projects
Commit b87e6d9c authored by Sourabh Desai's avatar Sourabh Desai
Browse files

finish implementation for llm list index retriever

parent 8d618a6b
No related branches found
No related tags found
No related merge requests found
...@@ -3,10 +3,13 @@ import { NodeWithScore } from "../../Node"; ...@@ -3,10 +3,13 @@ import { NodeWithScore } from "../../Node";
import { ListIndex } from "./ListIndex"; import { ListIndex } from "./ListIndex";
import { ServiceContext } from "../../ServiceContext"; import { ServiceContext } from "../../ServiceContext";
import { import {
NodeFormatterFunction,
ChoiceSelectParserFunction,
defaultFormatNodeBatchFn, defaultFormatNodeBatchFn,
defaultParseChoiceSelectAnswerFn, defaultParseChoiceSelectAnswerFn,
} from "./utils"; } from "./utils";
import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt"; import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
import _ from "lodash";
/** /**
* Simple retriever for ListIndex that returns all nodes * Simple retriever for ListIndex that returns all nodes
...@@ -34,16 +37,16 @@ export class ListIndexLLMRetriever implements BaseRetriever { ...@@ -34,16 +37,16 @@ export class ListIndexLLMRetriever implements BaseRetriever {
index: ListIndex; index: ListIndex;
choiceSelectPrompt: SimplePrompt; choiceSelectPrompt: SimplePrompt;
choiceBatchSize: number; choiceBatchSize: number;
formatNodeBatchFn: Function; formatNodeBatchFn: NodeFormatterFunction;
parseChoiceSelectAnswerFn: Function; parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
serviceContext: ServiceContext; serviceContext: ServiceContext;
constructor( constructor(
index: ListIndex, index: ListIndex,
choiceSelectPrompt?: SimplePrompt, choiceSelectPrompt?: SimplePrompt,
choiceBatchSize: number = 10, choiceBatchSize: number = 10,
formatNodeBatchFn?: Function, formatNodeBatchFn?: NodeFormatterFunction,
parseChoiceSelectAnswerFn?: Function, parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
serviceContext?: ServiceContext serviceContext?: ServiceContext
) { ) {
this.index = index; this.index = index;
...@@ -70,23 +73,19 @@ export class ListIndexLLMRetriever implements BaseRetriever { ...@@ -70,23 +73,19 @@ export class ListIndexLLMRetriever implements BaseRetriever {
input input
); );
const [rawChoices, relevances] = this.parseChoiceSelectAnswerFn( // parseResult is a map from doc number to relevance score
const parseResult = this.parseChoiceSelectAnswerFn(
rawResponse, rawResponse,
nodesBatch.length nodesBatch.length
); );
const choiceIndexes = rawChoices.map( const choiceNodeIds = nodeIdsBatch.filter((nodeId, idx) => {
(choice: string) => parseInt(choice) - 1 return `${idx}` in parseResult;
); });
const choiceNodeIds = choiceIndexes.map(
(idx: number) => nodeIdsBatch[idx]
);
const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds); const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds);
const relevancesFilled =
relevances || new Array(choiceNodes.length).fill(1.0);
const nodeWithScores = choiceNodes.map((node, i) => ({ const nodeWithScores = choiceNodes.map((node, i) => ({
node: node, node: node,
score: relevancesFilled[i], score: _.get(parseResult, `${i + 1}`, 1),
})); }));
results.push(...nodeWithScores); results.push(...nodeWithScores);
......
export function defaultFormatNodeBatchFn() {} import { BaseNode, MetadataMode } from "../../Node";
import _ from "lodash";
export function defaultParseChoiceSelectAnswerFn() {} export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string;
export const defaultFormatNodeBatchFn: NodeFormatterFunction = (
summaryNodes: BaseNode[]
): string => {
return summaryNodes
.map((node, idx) => {
return `
Document ${idx + 1}:
${node.getContent(MetadataMode.LLM)}
`.trim();
})
.join("\n\n");
};
// map from document number to its relevance score
export type ChoiceSelectParseResult = { [docNumber: number]: number };
export type ChoiceSelectParserFunction = (
answer: string,
numChoices: number,
raiseErr?: boolean
) => ChoiceSelectParseResult;
export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = (
answer: string,
numChoices: number,
raiseErr: boolean = false
): ChoiceSelectParseResult => {
// split the line into the answer number and relevance score portions
const lineTokens: string[][] = answer
.split("\n")
.map((line: string) => {
let lineTokens = line.split(",");
if (lineTokens.length !== 2) {
if (raiseErr) {
throw new Error(
`Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`
);
} else {
return null;
}
}
return lineTokens;
})
.filter((lineTokens) => !_.isNil(lineTokens)) as string[][];
// parse the answer number and relevance score
return lineTokens.reduce(
(parseResult: ChoiceSelectParseResult, lineToken: string[]) => {
try {
let docNum = parseInt(lineToken[0].split(":")[1].trim());
let answerRelevance = parseFloat(lineToken[1].split(":")[1].trim());
if (docNum < 1 || docNum > numChoices) {
if (raiseErr) {
throw new Error(
`Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}`
);
} else {
parseResult[docNum] = answerRelevance;
}
}
} catch (e) {
if (raiseErr) {
throw e;
}
}
return parseResult;
},
{}
);
};
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment