From 733d62a699bbbaace58cb439ace347fe7f6bc2db Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Fri, 5 Jan 2024 14:37:14 +0700 Subject: [PATCH] feat: add MetadataReplacementPostProcessor --- packages/core/src/ChatEngine.ts | 4 +-- packages/core/src/QueryEngine.ts | 9 +++-- packages/core/src/indices/index.ts | 1 - .../src/indices/keyword/KeywordTableIndex.ts | 9 +++-- .../core/src/indices/summary/SummaryIndex.ts | 7 ++-- .../indices/vectorStore/VectorStoreIndex.ts | 8 ++--- .../MetadataReplacementPostProcessor.ts | 21 ++++++++++++ .../SimilarityPostprocessor.ts} | 5 +-- packages/core/src/postprocessors/index.ts | 3 ++ packages/core/src/postprocessors/types.ts | 5 +++ .../MetadataReplacementPostProcessor.test.ts | 33 +++++++++++++++++++ 11 files changed, 85 insertions(+), 20 deletions(-) create mode 100644 packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts rename packages/core/src/{indices/BaseNodePostprocessor.ts => postprocessors/SimilarityPostprocessor.ts} (81%) create mode 100644 packages/core/src/postprocessors/index.ts create mode 100644 packages/core/src/postprocessors/types.ts create mode 100644 packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts diff --git a/packages/core/src/ChatEngine.ts b/packages/core/src/ChatEngine.ts index 5311ffa8a..1afbb13d2 100644 --- a/packages/core/src/ChatEngine.ts +++ b/packages/core/src/ChatEngine.ts @@ -13,8 +13,8 @@ import { Response } from "./Response"; import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; import { Event } from "./callbacks/CallbackManager"; -import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor"; -import { ChatMessage, LLM, OpenAI } from "./llm/LLM"; +import { ChatMessage, LLM, OpenAI } from "./llm"; +import { BaseNodePostprocessor } from "./postprocessors"; /** * A ChatEngine is used to handle back and forth chats between the application and the LLM. diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts index 8e032fbf8..aedbb9be7 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/QueryEngine.ts @@ -10,9 +10,12 @@ import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; import { QueryEngineTool, ToolMetadata } from "./Tool"; import { Event } from "./callbacks/CallbackManager"; -import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor"; -import { CompactAndRefine, ResponseSynthesizer } from "./synthesizers"; -import { BaseSynthesizer } from "./synthesizers/types"; +import { BaseNodePostprocessor } from "./postprocessors"; +import { + BaseSynthesizer, + CompactAndRefine, + ResponseSynthesizer, +} from "./synthesizers"; /** * A query engine is a question answerer that can use one or more steps. diff --git a/packages/core/src/indices/index.ts b/packages/core/src/indices/index.ts index ddfe185dc..8bda05b2d 100644 --- a/packages/core/src/indices/index.ts +++ b/packages/core/src/indices/index.ts @@ -1,5 +1,4 @@ export * from "./BaseIndex"; -export * from "./BaseNodePostprocessor"; export * from "./keyword"; export * from "./summary"; export * from "./vectorStore"; diff --git a/packages/core/src/indices/keyword/KeywordTableIndex.ts b/packages/core/src/indices/keyword/KeywordTableIndex.ts index 406809d88..5b67cec67 100644 --- a/packages/core/src/indices/keyword/KeywordTableIndex.ts +++ b/packages/core/src/indices/keyword/KeywordTableIndex.ts @@ -6,8 +6,12 @@ import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; -import { StorageContext, storageContextFromDefaults } from "../../storage"; -import { BaseDocumentStore } from "../../storage/docStore/types"; +import { BaseNodePostprocessor } from "../../postprocessors"; +import { + BaseDocumentStore, + StorageContext, + storageContextFromDefaults, +} from "../../storage"; import { BaseSynthesizer } from "../../synthesizers"; import { BaseIndex, @@ -15,7 +19,6 @@ import { IndexStructType, KeywordTable, } from "../BaseIndex"; -import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { KeywordTableLLMRetriever, KeywordTableRAKERetriever, diff --git a/packages/core/src/indices/summary/SummaryIndex.ts b/packages/core/src/indices/summary/SummaryIndex.ts index eb6f75333..391685dd8 100644 --- a/packages/core/src/indices/summary/SummaryIndex.ts +++ b/packages/core/src/indices/summary/SummaryIndex.ts @@ -6,11 +6,13 @@ import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; +import { BaseNodePostprocessor } from "../../postprocessors"; import { + BaseDocumentStore, + RefDocInfo, StorageContext, storageContextFromDefaults, -} from "../../storage/StorageContext"; -import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; +} from "../../storage"; import { BaseSynthesizer, CompactAndRefine, @@ -22,7 +24,6 @@ import { IndexList, IndexStructType, } from "../BaseIndex"; -import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { SummaryIndexLLMRetriever, SummaryIndexRetriever, diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index 278209f2c..59d177b34 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -17,12 +17,13 @@ import { ClipEmbedding, MultiModalEmbedding, } from "../../embeddings"; +import { BaseNodePostprocessor } from "../../postprocessors"; import { + BaseIndexStore, StorageContext, + VectorStore, storageContextFromDefaults, -} from "../../storage/StorageContext"; -import { BaseIndexStore } from "../../storage/indexStore/types"; -import { VectorStore } from "../../storage/vectorStore/types"; +} from "../../storage"; import { BaseSynthesizer } from "../../synthesizers"; import { BaseIndex, @@ -30,7 +31,6 @@ import { IndexDict, IndexStructType, } from "../BaseIndex"; -import { BaseNodePostprocessor } from "../BaseNodePostprocessor"; import { VectorIndexRetriever } from "./VectorIndexRetriever"; interface IndexStructOptions { diff --git a/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts new file mode 100644 index 000000000..d05d94a10 --- /dev/null +++ b/packages/core/src/postprocessors/MetadataReplacementPostProcessor.ts @@ -0,0 +1,21 @@ +import { MetadataMode, NodeWithScore } from "../Node"; +import { BaseNodePostprocessor } from "./types"; + +export class MetadataReplacementPostProcessor implements BaseNodePostprocessor { + targetMetadataKey: string; + + constructor(targetMetadataKey: string) { + this.targetMetadataKey = targetMetadataKey; + } + + postprocessNodes(nodes: NodeWithScore[]): NodeWithScore[] { + for (let n of nodes) { + n.node.setContent( + n.node.metadata[this.targetMetadataKey] ?? + n.node.getContent(MetadataMode.NONE), + ); + } + + return nodes; + } +} diff --git a/packages/core/src/indices/BaseNodePostprocessor.ts b/packages/core/src/postprocessors/SimilarityPostprocessor.ts similarity index 81% rename from packages/core/src/indices/BaseNodePostprocessor.ts rename to packages/core/src/postprocessors/SimilarityPostprocessor.ts index b47072894..91674515e 100644 --- a/packages/core/src/indices/BaseNodePostprocessor.ts +++ b/packages/core/src/postprocessors/SimilarityPostprocessor.ts @@ -1,8 +1,5 @@ import { NodeWithScore } from "../Node"; - -export interface BaseNodePostprocessor { - postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[]; -} +import { BaseNodePostprocessor } from "./types"; export class SimilarityPostprocessor implements BaseNodePostprocessor { similarityCutoff?: number; diff --git a/packages/core/src/postprocessors/index.ts b/packages/core/src/postprocessors/index.ts new file mode 100644 index 000000000..f79e4ced0 --- /dev/null +++ b/packages/core/src/postprocessors/index.ts @@ -0,0 +1,3 @@ +export * from "./MetadataReplacementPostProcessor"; +export * from "./SimilarityPostprocessor"; +export * from "./types"; diff --git a/packages/core/src/postprocessors/types.ts b/packages/core/src/postprocessors/types.ts new file mode 100644 index 000000000..2d0c73e78 --- /dev/null +++ b/packages/core/src/postprocessors/types.ts @@ -0,0 +1,5 @@ +import { NodeWithScore } from "../Node"; + +export interface BaseNodePostprocessor { + postprocessNodes: (nodes: NodeWithScore[]) => NodeWithScore[]; +} diff --git a/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts new file mode 100644 index 000000000..f9a76845d --- /dev/null +++ b/packages/core/src/tests/postprocessors/MetadataReplacementPostProcessor.test.ts @@ -0,0 +1,33 @@ +import { MetadataMode, NodeWithScore, TextNode } from "../../Node"; +import { MetadataReplacementPostProcessor } from "../../postprocessors"; + +describe("MetadataReplacementPostProcessor", () => { + let postProcessor: MetadataReplacementPostProcessor; + let nodes: NodeWithScore[]; + + beforeEach(() => { + postProcessor = new MetadataReplacementPostProcessor("targetKey"); + + nodes = [ + { + node: new TextNode({ + text: "OldContent", + }), + score: 5, + }, + ]; + }); + + test("Replaces the content of each node with specified metadata key if it exists", () => { + nodes[0].node.metadata = { targetKey: "NewContent" }; + const newNodes = postProcessor.postprocessNodes(nodes); + // Check if node content was replaced correctly + expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("NewContent"); + }); + + test("Retains the original content of each node if no metadata key is found", () => { + const newNodes = postProcessor.postprocessNodes(nodes); + // Check if node content remained unchanged + expect(newNodes[0].node.getContent(MetadataMode.NONE)).toBe("OldContent"); + }); +}); -- GitLab