From eedc14b13ce399f3ae1e86f67537e2255cfb35e6 Mon Sep 17 00:00:00 2001
From: TomPenguin <tom.penguin.zoo@gmail.com>
Date: Wed, 25 Oct 2023 12:36:03 +0900
Subject: [PATCH] Add node postprocessors to query and chat engines

---
 apps/simple/vectorIndexCustomize.ts           | 11 +++++++-
 packages/core/src/ChatEngine.ts               | 20 +++++++++++---
 packages/core/src/QueryEngine.ts              | 27 +++++++++++++++----
 .../core/src/indices/BaseNodePostprocessor.ts | 21 +++++++++++++++
 packages/core/src/indices/index.ts            |  1 +
 .../src/indices/keyword/KeywordTableIndex.ts  |  4 +++
 .../core/src/indices/summary/SummaryIndex.ts  | 11 ++++++--
 .../indices/vectorStore/VectorStoreIndex.ts   |  4 +++
 8 files changed, 87 insertions(+), 12 deletions(-)
 create mode 100644 packages/core/src/indices/BaseNodePostprocessor.ts

diff --git a/apps/simple/vectorIndexCustomize.ts b/apps/simple/vectorIndexCustomize.ts
index 5ad55cff6..b24e91416 100644
--- a/apps/simple/vectorIndexCustomize.ts
+++ b/apps/simple/vectorIndexCustomize.ts
@@ -3,6 +3,7 @@ import {
   OpenAI,
   RetrieverQueryEngine,
   serviceContextFromDefaults,
+  SimilarityPostprocessor,
   VectorStoreIndex,
 } from "llamaindex";
 import essay from "./essay";
@@ -21,8 +22,16 @@ async function main() {
 
   const retriever = index.asRetriever();
   retriever.similarityTopK = 5;
+  const nodePostprocessor = new SimilarityPostprocessor({
+    similarityCutoff: 0.7,
+  });
   // TODO: cannot pass responseSynthesizer into retriever query engine
-  const queryEngine = new RetrieverQueryEngine(retriever);
+  const queryEngine = new RetrieverQueryEngine(
+    retriever,
+    undefined,
+    undefined,
+    [nodePostprocessor],
+  );
 
   const response = await queryEngine.query(
     "What did the author do growing up?",
diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 84e9e1ae3..be28e9be8 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -1,6 +1,7 @@
 import { v4 as uuidv4 } from "uuid";
 import { Event } from "./callbacks/CallbackManager";
 import { ChatHistory } from "./ChatHistory";
+import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
 import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
 import { NodeWithScore, TextNode } from "./Node";
 import {
@@ -178,14 +179,24 @@ export interface ContextGenerator {
 export class DefaultContextGenerator implements ContextGenerator {
   retriever: BaseRetriever;
   contextSystemPrompt: ContextSystemPrompt;
+  nodePostprocessors: BaseNodePostprocessor[];
 
   constructor(init: {
     retriever: BaseRetriever;
     contextSystemPrompt?: ContextSystemPrompt;
+    nodePostprocessors?: BaseNodePostprocessor[];
   }) {
     this.retriever = init.retriever;
     this.contextSystemPrompt =
       init?.contextSystemPrompt ?? defaultContextSystemPrompt;
+    this.nodePostprocessors = init.nodePostprocessors || [];
+  }
+
+  private applyNodePostprocessors(nodes: NodeWithScore[]) {
+    return this.nodePostprocessors.reduce(
+      (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes),
+      nodes,
+    );
   }
 
   async generate(message: string, parentEvent?: Event): Promise<Context> {
@@ -201,16 +212,16 @@ export class DefaultContextGenerator implements ContextGenerator {
       parentEvent,
     );
 
+    const nodes = this.applyNodePostprocessors(sourceNodesWithScore);
+
     return {
       message: {
         content: this.contextSystemPrompt({
-          context: sourceNodesWithScore
-            .map((r) => (r.node as TextNode).text)
-            .join("\n\n"),
+          context: nodes.map((r) => (r.node as TextNode).text).join("\n\n"),
         }),
         role: "system",
       },
-      nodes: sourceNodesWithScore,
+      nodes,
     };
   }
 }
@@ -230,6 +241,7 @@ export class ContextChatEngine implements ChatEngine {
     chatModel?: LLM;
     chatHistory?: ChatMessage[];
     contextSystemPrompt?: ContextSystemPrompt;
+    nodePostprocessors?: BaseNodePostprocessor[];
   }) {
     this.chatModel =
       init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" });
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index daad1e6b3..abfb52d81 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -1,5 +1,6 @@
 import { v4 as uuidv4 } from "uuid";
 import { Event } from "./callbacks/CallbackManager";
+import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
 import { NodeWithScore, TextNode } from "./Node";
 import {
   BaseQuestionGenerator,
@@ -30,12 +31,14 @@ export interface BaseQueryEngine {
 export class RetrieverQueryEngine implements BaseQueryEngine {
   retriever: BaseRetriever;
   responseSynthesizer: ResponseSynthesizer;
+  nodePostprocessors: BaseNodePostprocessor[];
   preFilters?: unknown;
 
   constructor(
     retriever: BaseRetriever,
     responseSynthesizer?: ResponseSynthesizer,
     preFilters?: unknown,
+    nodePostprocessors?: BaseNodePostprocessor[],
   ) {
     this.retriever = retriever;
     const serviceContext: ServiceContext | undefined =
@@ -43,6 +46,24 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
     this.responseSynthesizer =
       responseSynthesizer || new ResponseSynthesizer({ serviceContext });
     this.preFilters = preFilters;
+    this.nodePostprocessors = nodePostprocessors || [];
+  }
+
+  private applyNodePostprocessors(nodes: NodeWithScore[]) {
+    return this.nodePostprocessors.reduce(
+      (nodes, nodePostprocessor) => nodePostprocessor.postprocessNodes(nodes),
+      nodes,
+    );
+  }
+
+  private async retrieve(query: string, parentEvent: Event) {
+    const nodes = await this.retriever.retrieve(
+      query,
+      parentEvent,
+      this.preFilters,
+    );
+
+    return this.applyNodePostprocessors(nodes);
   }
 
   async query(query: string, parentEvent?: Event) {
@@ -51,11 +72,7 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
       type: "wrapper",
       tags: ["final"],
     };
-    const nodes = await this.retriever.retrieve(
-      query,
-      _parentEvent,
-      this.preFilters,
-    );
+    const nodes = await this.retrieve(query, _parentEvent);
     return this.responseSynthesizer.synthesize(query, nodes, _parentEvent);
   }
 }
diff --git a/packages/core/src/indices/BaseNodePostprocessor.ts b/packages/core/src/indices/BaseNodePostprocessor.ts
new file mode 100644
index 000000000..7b408865b
--- /dev/null
+++ b/packages/core/src/indices/BaseNodePostprocessor.ts
@@ -0,0 +1,21 @@
+import { NodeWithScore } from "../Node";
+
+export interface BaseNodePostprocessor {
+  postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[];
+}
+
+export class SimilarityPostprocessor implements BaseNodePostprocessor {
+  similarityCutoff?: number;
+
+  constructor(options?: { similarityCutoff?: number }) {
+    this.similarityCutoff = options?.similarityCutoff;
+  }
+
+  /** Keeps nodes scored at or above the cutoff; returns all nodes if unset. */
+  postprocessNodes(nodes: NodeWithScore[]): NodeWithScore[] {
+    const cutoff = this.similarityCutoff;
+    if (cutoff === undefined) return nodes;
+    // Test for undefined, not truthiness, so a valid score of 0 is kept.
+    return nodes.filter((n) => n.score !== undefined && n.score >= cutoff);
+  }
+}
diff --git a/packages/core/src/indices/index.ts b/packages/core/src/indices/index.ts
index 8bda05b2d..ddfe185dc 100644
--- a/packages/core/src/indices/index.ts
+++ b/packages/core/src/indices/index.ts
@@ -1,4 +1,5 @@
 export * from "./BaseIndex";
+export * from "./BaseNodePostprocessor";
 export * from "./keyword";
 export * from "./summary";
 export * from "./vectorStore";
diff --git a/packages/core/src/indices/keyword/KeywordTableIndex.ts b/packages/core/src/indices/keyword/KeywordTableIndex.ts
index a305cc2c9..91de6201e 100644
--- a/packages/core/src/indices/keyword/KeywordTableIndex.ts
+++ b/packages/core/src/indices/keyword/KeywordTableIndex.ts
@@ -15,6 +15,7 @@ import {
   IndexStructType,
   KeywordTable,
 } from "../BaseIndex";
+import { BaseNodePostprocessor } from "../BaseNodePostprocessor";
 import {
   KeywordTableLLMRetriever,
   KeywordTableRAKERetriever,
@@ -129,11 +130,14 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> {
   asQueryEngine(options?: {
     retriever?: BaseRetriever;
     responseSynthesizer?: ResponseSynthesizer;
+    nodePostprocessors?: BaseNodePostprocessor[];
   }): BaseQueryEngine {
     const { retriever, responseSynthesizer } = options ?? {};
     return new RetrieverQueryEngine(
       retriever ?? this.asRetriever(),
       responseSynthesizer,
+      undefined,
+      options?.nodePostprocessors,
     );
   }
 
diff --git a/packages/core/src/indices/summary/SummaryIndex.ts b/packages/core/src/indices/summary/SummaryIndex.ts
index 39a8ec525..91b12ba9b 100644
--- a/packages/core/src/indices/summary/SummaryIndex.ts
+++ b/packages/core/src/indices/summary/SummaryIndex.ts
@@ -10,17 +10,18 @@ import {
   ServiceContext,
   serviceContextFromDefaults,
 } from "../../ServiceContext";
+import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
 import {
   StorageContext,
   storageContextFromDefaults,
 } from "../../storage/StorageContext";
-import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
 import {
   BaseIndex,
   BaseIndexInit,
   IndexList,
   IndexStructType,
 } from "../BaseIndex";
+import { BaseNodePostprocessor } from "../BaseNodePostprocessor";
 import {
   SummaryIndexLLMRetriever,
   SummaryIndexRetriever,
@@ -155,6 +156,7 @@ export class SummaryIndex extends BaseIndex<IndexList> {
   asQueryEngine(options?: {
     retriever?: BaseRetriever;
     responseSynthesizer?: ResponseSynthesizer;
+    nodePostprocessors?: BaseNodePostprocessor[];
   }): BaseQueryEngine {
     let { retriever, responseSynthesizer } = options ?? {};
 
@@ -170,7 +172,12 @@ export class SummaryIndex extends BaseIndex<IndexList> {
       });
     }
 
-    return new RetrieverQueryEngine(retriever, responseSynthesizer);
+    return new RetrieverQueryEngine(
+      retriever,
+      responseSynthesizer,
+      undefined,
+      options?.nodePostprocessors,
+    );
   }
 
   static async buildIndexFromNodes(
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index 08499e5e9..ff34df502 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -18,6 +18,7 @@ import {
   IndexDict,
   IndexStructType,
 } from "../BaseIndex";
+import { BaseNodePostprocessor } from "../BaseNodePostprocessor";
 import { VectorIndexRetriever } from "./VectorIndexRetriever";
 
 export interface VectorIndexOptions {
@@ -246,11 +247,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
   asQueryEngine(options?: {
     retriever?: BaseRetriever;
     responseSynthesizer?: ResponseSynthesizer;
+    nodePostprocessors?: BaseNodePostprocessor[];
   }): BaseQueryEngine {
     const { retriever, responseSynthesizer } = options ?? {};
     return new RetrieverQueryEngine(
       retriever ?? this.asRetriever(),
       responseSynthesizer,
+      undefined,
+      options?.nodePostprocessors,
     );
   }
 
-- 
GitLab