From 733d62a699bbbaace58cb439ace347fe7f6bc2db Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Fri, 5 Jan 2024 14:37:14 +0700
Subject: [PATCH] feat: add MetadataReplacementPostProcessor

---
 packages/core/src/ChatEngine.ts               |  4 +--
 packages/core/src/QueryEngine.ts              |  9 +++--
 packages/core/src/indices/index.ts            |  1 -
 .../src/indices/keyword/KeywordTableIndex.ts  |  9 +++--
 .../core/src/indices/summary/SummaryIndex.ts  |  7 ++--
 .../indices/vectorStore/VectorStoreIndex.ts   |  8 ++---
 .../MetadataReplacementPostProcessor.ts       | 21 ++++++++++++
 .../SimilarityPostprocessor.ts}               |  5 +--
 packages/core/src/postprocessors/index.ts     |  3 ++
 packages/core/src/postprocessors/types.ts     |  5 +++
 .../MetadataReplacementPostProcessor.test.ts  | 33 +++++++++++++++++++
 11 files changed, 85 insertions(+), 20 deletions(-)
 create mode 100644 packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts
 rename packages/core/src/{indices/BaseNodePostprocessor.ts => postprocessors/SimilarityPostprocessor.ts} (81%)
 create mode 100644 packages/core/src/postprocessors/index.ts
 create mode 100644 packages/core/src/postprocessors/types.ts
 create mode 100644 packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts

diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts
index 5311ffa8a..1afbb13d2 100644
--- a/packages/core/src/ChatEngine.ts
+++ b/packages/core/src/ChatEngine.ts
@@ -13,8 +13,8 @@ import { Response } from "./Response";
 import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { Event } from "./callbacks/CallbackManager";
-import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
-import { ChatMessage, LLM, OpenAI } from "./llm/LLM";
+import { ChatMessage, LLM, OpenAI } from "./llm";
+import { BaseNodePostprocessor } from "./postprocessors";
 
 /**
  * A ChatEngine is used to handle back and forth chats between the application and the LLM.
diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts
index 8e032fbf8..aedbb9be7 100644
--- a/packages/core/src/QueryEngine.ts
+++ b/packages/core/src/QueryEngine.ts
@@ -10,9 +10,12 @@ import { BaseRetriever } from "./Retriever";
 import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
 import { QueryEngineTool, ToolMetadata } from "./Tool";
 import { Event } from "./callbacks/CallbackManager";
-import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
-import { CompactAndRefine, ResponseSynthesizer } from "./synthesizers";
-import { BaseSynthesizer } from "./synthesizers/types";
+import { BaseNodePostprocessor } from "./postprocessors";
+import {
+  BaseSynthesizer,
+  CompactAndRefine,
+  ResponseSynthesizer,
+} from "./synthesizers";
 
 /**
  * A query engine is a question answerer that can use one or more steps.
diff --git a/packages/core/src/indices/index.ts b/packages/core/src/indices/index.ts
index ddfe185dc..8bda05b2d 100644
--- a/packages/core/src/indices/index.ts
+++ b/packages/core/src/indices/index.ts
@@ -1,5 +1,4 @@
 export * from "./BaseIndex";
-export * from "./BaseNodePostprocessor";
 export * from "./keyword";
 export * from "./summary";
 export * from "./vectorStore";
diff --git a/packages/core/src/indices/keyword/KeywordTableIndex.ts b/packages/core/src/indices/keyword/KeywordTableIndex.ts
index 406809d88..5b67cec67 100644
--- a/packages/core/src/indices/keyword/KeywordTableIndex.ts
+++ b/packages/core/src/indices/keyword/KeywordTableIndex.ts
@@ -6,8 +6,12 @@ import {
   ServiceContext,
   serviceContextFromDefaults,
 } from "../../ServiceContext";
-import { StorageContext, storageContextFromDefaults } from "../../storage";
-import { BaseDocumentStore } from "../../storage/docStore/types";
+import { BaseNodePostprocessor } from "../../postprocessors";
+import {
+  BaseDocumentStore,
+  StorageContext,
+  storageContextFromDefaults,
+} from "../../storage";
 import { BaseSynthesizer } from "../../synthesizers";
 import {
   BaseIndex,
@@ -15,7 +19,6 @@ import {
   IndexStructType,
   KeywordTable,
 } from "../BaseIndex";
-import { BaseNodePostprocessor } from "../BaseNodePostprocessor";
 import {
   KeywordTableLLMRetriever,
   KeywordTableRAKERetriever,
diff --git a/packages/core/src/indices/summary/SummaryIndex.ts b/packages/core/src/indices/summary/SummaryIndex.ts
index eb6f75333..391685dd8 100644
--- a/packages/core/src/indices/summary/SummaryIndex.ts
+++ b/packages/core/src/indices/summary/SummaryIndex.ts
@@ -6,11 +6,13 @@ import {
   ServiceContext,
   serviceContextFromDefaults,
 } from "../../ServiceContext";
+import { BaseNodePostprocessor } from "../../postprocessors";
 import {
+  BaseDocumentStore,
+  RefDocInfo,
   StorageContext,
   storageContextFromDefaults,
-} from "../../storage/StorageContext";
-import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
+} from "../../storage";
 import {
   BaseSynthesizer,
   CompactAndRefine,
@@ -22,7 +24,6 @@ import {
   IndexList,
   IndexStructType,
 } from "../BaseIndex";
-import { BaseNodePostprocessor } from "../BaseNodePostprocessor";
 import {
   SummaryIndexLLMRetriever,
   SummaryIndexRetriever,
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index 278209f2c..59d177b34 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -17,12 +17,13 @@ import {
   ClipEmbedding,
   MultiModalEmbedding,
 } from "../../embeddings";
+import { BaseNodePostprocessor } from "../../postprocessors";
 import {
+  BaseIndexStore,
   StorageContext,
+  VectorStore,
   storageContextFromDefaults,
-} from "../../storage/StorageContext";
-import { BaseIndexStore } from "../../storage/indexStore/types";
-import { VectorStore } from "../../storage/vectorStore/types";
+} from "../../storage";
 import { BaseSynthesizer } from "../../synthesizers";
 import {
   BaseIndex,
@@ -30,7 +31,6 @@ import {
   IndexDict,
   IndexStructType,
 } from "../BaseIndex";
-import { BaseNodePostprocessor } from "../BaseNodePostprocessor";
 import { VectorIndexRetriever } from "./VectorIndexRetriever";
 
 interface IndexStructOptions {
diff --git a/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts
new file mode 100644
index 000000000..d05d94a10
--- /dev/null
+++ b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts
@@ -0,0 +1,21 @@
+import { MetadataMode, NodeWithScore } from "../Node";
+import { BaseNodePostprocessor } from "./types";
+
+export class MetadataReplacementPostProcessor implements BaseNodePostprocessor {
+  targetMetadataKey: string;
+
+  constructor(targetMetadataKey: string) {
+    this.targetMetadataKey = targetMetadataKey;
+  }
+
+  postprocessNodes(nodes: NodeWithScore[]): NodeWithScore[] {
+    for (let n of nodes) {
+      n.node.setContent(
+        n.node.metadata[this.targetMetadataKey] ??
+          n.node.getContent(MetadataMode.NONE),
+      );
+    }
+
+    return nodes;
+  }
+}
diff --git a/packages/core/src/indices/BaseNodePostprocessor.ts b/packages/core/src/postprocessors/SimilarityPostprocessor.ts
similarity index 81%
rename from packages/core/src/indices/BaseNodePostprocessor.ts
rename to packages/core/src/postprocessors/SimilarityPostprocessor.ts
index b47072894..91674515e 100644
--- a/packages/core/src/indices/BaseNodePostprocessor.ts
+++ b/packages/core/src/postprocessors/SimilarityPostprocessor.ts
@@ -1,8 +1,5 @@
 import { NodeWithScore } from "../Node";
-
-export interface BaseNodePostprocessor {
-  postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[];
-}
+import { BaseNodePostprocessor } from "./types";
 
 export class SimilarityPostprocessor implements BaseNodePostprocessor {
   similarityCutoff?: number;
diff --git a/packages/core/src/postprocessors/index.ts b/packages/core/src/postprocessors/index.ts
new file mode 100644
index 000000000..f79e4ced0
--- /dev/null
+++ b/packages/core/src/postprocessors/index.ts
@@ -0,0 +1,3 @@
+export * from "./MetadataReplacementPostProcessor";
+export * from "./SimilarityPostprocessor";
+export * from "./types";
diff --git a/packages/core/src/postprocessors/types.ts b/packages/core/src/postprocessors/types.ts
new file mode 100644
index 000000000..2d0c73e78
--- /dev/null
+++ b/packages/core/src/postprocessors/types.ts
@@ -0,0 +1,5 @@
+import { NodeWithScore } from "../Node";
+
+export interface BaseNodePostprocessor {
+  postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[];
+}
diff --git a/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts
new file mode 100644
index 000000000..f9a76845d
--- /dev/null
+++ b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts
@@ -0,0 +1,33 @@
+import { MetadataMode, NodeWithScore, TextNode } from "../../Node";
+import { MetadataReplacementPostProcessor } from "../../postprocessors";
+
+describe("MetadataReplacementPostProcessor", () => {
+  let postProcessor: MetadataReplacementPostProcessor;
+  let nodes: NodeWithScore[];
+
+  beforeEach(() => {
+    postProcessor = new MetadataReplacementPostProcessor("targetKey");
+
+    nodes = [
+      {
+        node: new TextNode({
+          text: "OldContent",
+        }),
+        score: 5,
+      },
+    ];
+  });
+
+  test("Replaces the content of each node with specified metadata key if it exists", () => {
+    nodes[0].node.metadata = { targetKey: "NewContent" };
+    const newNodes = postProcessor.postprocessNodes(nodes);
+    // Check if node content was replaced correctly
+    expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("NewContent");
+  });
+
+  test("Retains the original content of each node if no metadata key is found", () => {
+    const newNodes = postProcessor.postprocessNodes(nodes);
+    // Check if node content remained unchanged
+    expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("OldContent");
+  });
+});
-- 
GitLab