From eedc14b13ce399f3ae1e86f67537e2255cfb35e6 Mon Sep 17 00:00:00 2001 From: TomPenguin <tom.penguin.zoo@gmail.com> Date: Wed, 25 Oct 2023 12:36:03 +0900 Subject: [PATCH] fix --- apps/simple/vectorIndexCustomize.ts | 11 +++++++- packages/core/src/ChatEngine.ts | 20 +++++++++++--- packages/core/src/QueryEngine.ts | 27 +++++++++++++++---- .../core/src/indices/BaseNodePostprocessor.ts | 21 +++++++++++++++ packages/core/src/indices/index.ts | 1 + .../src/indices/keyword/KeywordTableIndex.ts | 4 +++ .../core/src/indices/summary/SummaryIndex.ts | 11 ++++++-- .../indices/vectorStore/VectorStoreIndex.ts | 4 +++ 8 files changed, 87 insertions(+), 12 deletions(-) create mode 100644 packages/core/src/indices/BaseNodePostprocessor.ts diff --git a/apps/simple/vectorIndexCustomize.ts b/apps/simple/vectorIndexCustomize.ts index 5ad55cff6..b24e91416 100644 --- a/apps/simple/vectorIndexCustomize.ts +++ b/apps/simple/vectorIndexCustomize.ts @@ -3,6 +3,7 @@ import { OpenAI, RetrieverQueryEngine, serviceContextFromDefaults, + SimilarityPostprocessor, VectorStoreIndex, } from "llamaindex"; import essay from "./essay"; @@ -21,8 +22,16 @@ async function main() { const retriever = index.asRetriever(); retriever.similarityTopK = 5; + const nodePostprocessor = new SimilarityPostprocessor({ + similarityCutoff: 0.7, + }); // TODO: cannot pass responseSynthesizer into retriever query engine - const queryEngine = new RetrieverQueryEngine(retriever); + const queryEngine = new RetrieverQueryEngine( + retriever, + undefined, + undefined, + [nodePostprocessor], + ); const response = await queryEngine.query( "What did the author do growing up?", diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index 84e9e1ae3..be28e9be8 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -1,6 +1,7 @@ import { v4 as uuidv4 } from "uuid"; import { Event } from "./callbacks/CallbackManager"; import { ChatHistory } from "./ChatHistory"; +import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor"; import { ChatMessage, LLM, OpenAI } from "./llm/LLM"; import { NodeWithScore, TextNode } from "./Node"; import { @@ -178,14 +179,24 @@ export interface ContextGenerator { export class DefaultContextGenerator implements ContextGenerator { retriever: BaseRetriever; contextSystemPrompt: ContextSystemPrompt; + nodePostprocessors: BaseNodePostprocessor[]; constructor(init: { retriever: BaseRetriever; contextSystemPrompt?: ContextSystemPrompt; + nodePostprocessors?: BaseNodePostprocessor[]; }) { this.retriever = init.retriever; this.contextSystemPrompt = init?.contextSystemPrompt ?? defaultContextSystemPrompt; + this.nodePostprocessors = init.nodePostprocessors || []; + } + + private applyNodePostprocessors(nodes: NodeWithScore[]) { + return this.nodePostprocessors.reduce( + (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), + nodes, + ); } async generate(message: string, parentEvent?: Event): Promise<Context> { @@ -201,16 +212,16 @@ export class DefaultContextGenerator implements ContextGenerator { parentEvent, ); + const nodes = this.applyNodePostprocessors(sourceNodesWithScore); + return { message: { content: this.contextSystemPrompt({ - context: sourceNodesWithScore - .map((r) => (r.node as TextNode).text) - .join("\n\n"), + context: nodes.map((r) => (r.node as TextNode).text).join("\n\n"), }), role: "system", }, - nodes: sourceNodesWithScore, + nodes, }; } } @@ -230,6 +241,7 @@ export class ContextChatEngine implements ChatEngine { chatModel?: LLM; chatHistory?: ChatMessage[]; contextSystemPrompt?: ContextSystemPrompt; + nodePostprocessors?: BaseNodePostprocessor[]; }) { this.chatModel = init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" }); diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts index daad1e6b3..abfb52d81 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/QueryEngine.ts @@ -1,5 +1,6 @@ import { v4 as uuidv4 } from "uuid"; import { Event } from "./callbacks/CallbackManager"; +import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor"; import { NodeWithScore, TextNode } from "./Node"; import { BaseQuestionGenerator, @@ -30,12 +31,14 @@ export interface BaseQueryEngine { export class RetrieverQueryEngine implements BaseQueryEngine { retriever: BaseRetriever; responseSynthesizer: ResponseSynthesizer; + nodePostprocessors: BaseNodePostprocessor[]; preFilters?: unknown; constructor( retriever: BaseRetriever, responseSynthesizer?: ResponseSynthesizer, preFilters?: unknown, + nodePostprocessors?: BaseNodePostprocessor[], ) { this.retriever = retriever; const serviceContext: ServiceContext | undefined = @@ -43,6 +46,24 @@ export class RetrieverQueryEngine implements BaseQueryEngine { this.responseSynthesizer = responseSynthesizer || new ResponseSynthesizer({ serviceContext }); this.preFilters = preFilters; + this.nodePostprocessors = nodePostprocessors || []; + } + + private applyNodePostprocessors(nodes: NodeWithScore[]) { + return this.nodePostprocessors.reduce( + (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes), + nodes, + ); + } + + private async retrieve(query: string, parentEvent: Event) { + const nodes = await this.retriever.retrieve( + query, + parentEvent, + this.preFilters, + ); + + return this.applyNodePostprocessors(nodes); } async query(query: string, parentEvent?: Event) { @@ -51,11 +72,7 @@ export class RetrieverQueryEngine implements BaseQueryEngine { type: "wrapper", tags: ["final"], }; - const nodes = await this.retriever.retrieve( - query, - _parentEvent, - this.preFilters, - ); + const nodes = await this.retrieve(query, _parentEvent); return this.responseSynthesizer.synthesize(query, nodes, _parentEvent); } } diff --git a/packages/core/src/indices/BaseNodePostprocessor.ts b/packages/core/src/indices/BaseNodePostprocessor.ts new file mode 100644 index 000000000..7b408865b --- /dev/null +++ b/packages/core/src/indices/BaseNodePostprocessor.ts @@ -0,0 +1,21 @@ +import { NodeWithScore } from "../Node"; + +export interface BaseNodePostprocessor { + postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[]; +} + +export class SimilarityPostprocessor implements BaseNodePostprocessor { + similarityCutoff?: number; + + constructor(options?: { similarityCutoff?: number }) { + this.similarityCutoff = options?.similarityCutoff; + } + + postprocessNodes(nodes: NodeWithScore[]) { + if (this.similarityCutoff === undefined) return nodes; + + const cutoff = this.similarityCutoff || 0; + console.log(nodes); + return nodes.filter((node) => node.score && node.score >= cutoff); + } +} diff --git a/packages/core/src/indices/index.ts b/packages/core/src/indices/index.ts index 8bda05b2d..ddfe185dc 100644 --- a/packages/core/src/indices/index.ts +++ b/packages/core/src/indices/index.ts @@ -1,4 +1,5 @@ export * from "./BaseIndex"; +export * from "./BaseNodePostprocessor"; export * from "./keyword"; export * from "./summary"; export * from "./vectorStore"; diff --git a/packages/core/src/indices/keyword/KeywordTableIndex.ts b/packages/core/src/indices/keyword/KeywordTableIndex.ts index a305cc2c9..91de6201e 100644 --- a/packages/core/src/indices/keyword/KeywordTableIndex.ts +++ b/packages/core/src/indices/keyword/KeywordTableIndex.ts @@ -15,6 +15,7 @@ import { IndexStructType, KeywordTable, } from "../BaseIndex"; +import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { KeywordTableLLMRetriever, KeywordTableRAKERetriever, @@ -129,11 +130,14 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> { asQueryEngine(options?: { retriever?: BaseRetriever; responseSynthesizer?: ResponseSynthesizer; + nodePostprocessors?: BaseNodePostprocessor[]; }): BaseQueryEngine { const { retriever, responseSynthesizer } = options ?? {}; return new RetrieverQueryEngine( retriever ?? this.asRetriever(), responseSynthesizer, + undefined, + options?.nodePostprocessors, ); } diff --git a/packages/core/src/indices/summary/SummaryIndex.ts b/packages/core/src/indices/summary/SummaryIndex.ts index 39a8ec525..91b12ba9b 100644 --- a/packages/core/src/indices/summary/SummaryIndex.ts +++ b/packages/core/src/indices/summary/SummaryIndex.ts @@ -10,17 +10,18 @@ import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; +import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; import { StorageContext, storageContextFromDefaults, } from "../../storage/StorageContext"; -import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; import { BaseIndex, BaseIndexInit, IndexList, IndexStructType, } from "../BaseIndex"; +import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { SummaryIndexLLMRetriever, SummaryIndexRetriever, @@ -155,6 +156,7 @@ export class SummaryIndex extends BaseIndex<IndexList> { asQueryEngine(options?: { retriever?: BaseRetriever; responseSynthesizer?: ResponseSynthesizer; + nodePostprocessors?: BaseNodePostprocessor[]; }): BaseQueryEngine { let { retriever, responseSynthesizer } = options ?? {}; @@ -170,7 +172,12 @@ export class SummaryIndex extends BaseIndex<IndexList> { }); } - return new RetrieverQueryEngine(retriever, responseSynthesizer); + return new RetrieverQueryEngine( + retriever, + responseSynthesizer, + undefined, + options?.nodePostprocessors, + ); } static async buildIndexFromNodes( diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index 08499e5e9..ff34df502 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -18,6 +18,7 @@ import { IndexDict, IndexStructType, } from "../BaseIndex"; +import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { VectorIndexRetriever } from "./VectorIndexRetriever"; export interface VectorIndexOptions { @@ -246,11 +247,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { asQueryEngine(options?: { retriever?: BaseRetriever; responseSynthesizer?: ResponseSynthesizer; + nodePostprocessors?: BaseNodePostprocessor[]; }): BaseQueryEngine { const { retriever, responseSynthesizer } = options ?? {}; return new RetrieverQueryEngine( retriever ?? this.asRetriever(), responseSynthesizer, + undefined, + options?.nodePostprocessors, ); } -- GitLab