diff --git a/packages/core/src/index/list/ListIndexRetriever.ts b/packages/core/src/index/list/ListIndexRetriever.ts
index 59760f2110e68083ec11570af22e0beb8551bfd7..86b0a7827c5b6fd36d6ade64f7209b1c13b3611b 100644
--- a/packages/core/src/index/list/ListIndexRetriever.ts
+++ b/packages/core/src/index/list/ListIndexRetriever.ts
@@ -3,10 +3,13 @@ import { NodeWithScore } from "../../Node";
 import { ListIndex } from "./ListIndex";
 import { ServiceContext } from "../../ServiceContext";
 import {
+  NodeFormatterFunction,
+  ChoiceSelectParserFunction,
   defaultFormatNodeBatchFn,
   defaultParseChoiceSelectAnswerFn,
 } from "./utils";
 import { SimplePrompt, defaultChoiceSelectPrompt } from "../../Prompt";
+import _ from "lodash";
 
 /**
  * Simple retriever for ListIndex that returns all nodes
@@ -34,16 +37,16 @@ export class ListIndexLLMRetriever implements BaseRetriever {
   index: ListIndex;
   choiceSelectPrompt: SimplePrompt;
   choiceBatchSize: number;
-  formatNodeBatchFn: Function;
-  parseChoiceSelectAnswerFn: Function;
+  formatNodeBatchFn: NodeFormatterFunction;
+  parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
   serviceContext: ServiceContext;
 
   constructor(
     index: ListIndex,
     choiceSelectPrompt?: SimplePrompt,
     choiceBatchSize: number = 10,
-    formatNodeBatchFn?: Function,
-    parseChoiceSelectAnswerFn?: Function,
+    formatNodeBatchFn?: NodeFormatterFunction,
+    parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
     serviceContext?: ServiceContext
   ) {
     this.index = index;
@@ -70,23 +73,27 @@ export class ListIndexLLMRetriever implements BaseRetriever {
         input
       );
 
-      const [rawChoices, relevances] = this.parseChoiceSelectAnswerFn(
+      // parseResult maps 1-based document numbers (positions within this
+      // batch) to relevance scores
+      const parseResult = this.parseChoiceSelectAnswerFn(
         rawResponse,
         nodesBatch.length
       );
-      const choiceIndexes = rawChoices.map(
-        (choice: string) => parseInt(choice) - 1
-      );
-      const choiceNodeIds = choiceIndexes.map(
-        (idx: number) => nodeIdsBatch[idx]
-      );
+      // Document numbers are 1-based: batch index `idx` is document `idx + 1`.
+      const choiceDocNumbers = nodeIdsBatch
+        .map((nodeId, idx) => idx + 1)
+        .filter((docNum) => `${docNum}` in parseResult);
+      const choiceNodeIds = choiceDocNumbers.map(
+        (docNum) => nodeIdsBatch[docNum - 1]
+      );
 
       const choiceNodes = await this.index.docStore.getNodes(choiceNodeIds);
-      const relevancesFilled =
-        relevances || new Array(choiceNodes.length).fill(1.0);
       const nodeWithScores = choiceNodes.map((node, i) => ({
         node: node,
-        score: relevancesFilled[i],
+        // Score by the node's original document number, not its position in
+        // the filtered list (assumes getNodes preserves the order of
+        // choiceNodeIds; TODO confirm).
+        score: _.get(parseResult, `${choiceDocNumbers[i]}`, 1),
       }));
 
       results.push(...nodeWithScores);
diff --git a/packages/core/src/index/list/utils.ts b/packages/core/src/index/list/utils.ts
index 33a9a3ce50ab2d8d0b92a146e000ef006f411676..b7a1d3f8fddc0af2c587ef1db2c8ff0db01dbfce 100644
--- a/packages/core/src/index/list/utils.ts
+++ b/packages/core/src/index/list/utils.ts
@@ -1,3 +1,104 @@
-export function defaultFormatNodeBatchFn() {}
+import { BaseNode, MetadataMode } from "../../Node";
+import _ from "lodash";
 
-export function defaultParseChoiceSelectAnswerFn() {}
+/** Formats a batch of nodes into a single string. */
+export type NodeFormatterFunction = (summaryNodes: BaseNode[]) => string;
+
+/**
+ * Default node formatter: renders each node as a `Document <n>:` header
+ * (numbered from 1) followed by the node's content in LLM metadata mode,
+ * with entries separated by a blank line.
+ */
+export const defaultFormatNodeBatchFn: NodeFormatterFunction = (
+  summaryNodes: BaseNode[]
+): string => {
+  return summaryNodes
+    .map((node, idx) => {
+      return `
+Document ${idx + 1}:
+${node.getContent(MetadataMode.LLM)}
+      `.trim();
+    })
+    .join("\n\n");
+};
+
+/** Map from 1-based document number to its relevance score. */
+export type ChoiceSelectParseResult = { [docNumber: number]: number };
+
+/**
+ * Parses an LLM choice-select answer into a {@link ChoiceSelectParseResult}.
+ * `raiseErr` selects between throwing on bad input and skipping it.
+ */
+export type ChoiceSelectParserFunction = (
+  answer: string,
+  numChoices: number,
+  raiseErr?: boolean
+) => ChoiceSelectParseResult;
+
+/**
+ * Default choice-select answer parser.
+ *
+ * Every line of `answer` is expected to have the form
+ * `answer_num: <int>, answer_relevance: <float>`. Only the segment after the
+ * first `:` in each comma-separated half is parsed; the labels themselves are
+ * not checked.
+ *
+ * @param answer raw answer text, one choice per line
+ * @param numChoices number of documents offered; valid numbers are 1..numChoices
+ * @param raiseErr when true, throw on malformed lines and out-of-range numbers
+ *   instead of silently skipping them
+ * @returns map from document number to relevance score
+ * @throws if `raiseErr` is true and a line is malformed or its answer number
+ *   does not parse to a value between 1 and `numChoices`
+ */
+export const defaultParseChoiceSelectAnswerFn: ChoiceSelectParserFunction = (
+  answer: string,
+  numChoices: number,
+  raiseErr: boolean = false
+): ChoiceSelectParseResult => {
+  // split the line into the answer number and relevance score portions
+  const lineTokens: string[][] = answer
+    .split("\n")
+    .map((line: string) => {
+      const tokens = line.split(",");
+      if (tokens.length !== 2) {
+        if (raiseErr) {
+          throw new Error(
+            `Invalid answer line: ${line}. Answer line must be of the form: answer_num: <int>, answer_relevance: <float>`
+          );
+        }
+        return null;
+      }
+      return tokens;
+    })
+    .filter((tokens): tokens is string[] => !_.isNil(tokens));
+
+  // parse the answer number and relevance score
+  return lineTokens.reduce(
+    (parseResult: ChoiceSelectParseResult, lineToken: string[]) => {
+      try {
+        const docNum = parseInt(lineToken[0].split(":")[1].trim(), 10);
+        const answerRelevance = parseFloat(lineToken[1].split(":")[1].trim());
+        // NaN fails both range comparisons, so reject it explicitly.
+        if (Number.isNaN(docNum) || docNum < 1 || docNum > numChoices) {
+          if (raiseErr) {
+            throw new Error(
+              `Invalid answer number: ${docNum}. Answer number must be between 1 and ${numChoices}`
+            );
+          }
+          // Invalid choices are dropped when not raising.
+        } else {
+          parseResult[docNum] = answerRelevance;
+        }
+      } catch (e) {
+        // Malformed tokens (e.g. a missing ":") throw above; swallow them
+        // unless the caller asked for errors.
+        if (raiseErr) {
+          throw e;
+        }
+      }
+      return parseResult;
+    },
+    {}
+  );
+};