From 993d6094b2f649bc1327429898cd0436a64cc582 Mon Sep 17 00:00:00 2001
From: Pratik Singh Chauhan <pratiksingh773@gmail.com>
Date: Mon, 20 May 2024 08:08:00 +0530
Subject: [PATCH] Pass 'excluded_llm_metadata_keys' and
 'excluded_embed_metadata_keys' in Node Parsers (#13567)

---
 .../core/node_parser/relational/base_element.py      | 12 +++++++++---
 .../relational/llama_parse_json_element.py           |  2 +-
 .../core/node_parser/relational/markdown_element.py  |  2 +-
 .../node_parser/relational/unstructured_element.py   |  2 +-
 4 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
index 938b9b5227..26bdf060b9 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
@@ -245,7 +245,7 @@ class BaseElementNodeParser(NodeParser):
     def get_nodes_from_elements(
         self,
         elements: List[Element],
-        metadata_inherited: Optional[Dict[str, Any]] = None,
+        node_inherited: Optional[TextNode] = None,
         ref_doc_text: Optional[str] = None,
     ) -> List[BaseNode]:
         """Get nodes and mappings."""
@@ -357,8 +357,14 @@ class BaseElementNodeParser(NodeParser):
 
         # remove empty nodes and keep node original metadata inherited from parent nodes
         for node in nodes:
-            if metadata_inherited:
-                node.metadata.update(metadata_inherited)
+            if node_inherited and node_inherited.metadata:
+                node.metadata.update(node_inherited.metadata)
+                node.excluded_embed_metadata_keys = (
+                    node_inherited.excluded_embed_metadata_keys
+                )
+                node.excluded_llm_metadata_keys = (
+                    node_inherited.excluded_llm_metadata_keys
+                )
         return [node for node in nodes if len(node.text) > 0]
 
     def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
diff --git a/llama-index-core/llama_index/core/node_parser/relational/llama_parse_json_element.py b/llama-index-core/llama_index/core/node_parser/relational/llama_parse_json_element.py
index 013db27743..c1a6e6ea9a 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/llama_parse_json_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/llama_parse_json_element.py
@@ -33,7 +33,7 @@ class LlamaParseJsonNodeParser(BaseElementNodeParser):
         self.extract_table_summaries(table_elements)
         # convert into nodes
         # will return a list of Nodes and Index Nodes
-        return self.get_nodes_from_elements(elements, node.metadata)
+        return self.get_nodes_from_elements(elements, node)
 
     def extract_elements(
         self,
diff --git a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
index bd65ac55bd..59ddd33a97 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
@@ -33,7 +33,7 @@ class MarkdownElementNodeParser(BaseElementNodeParser):
         # convert into nodes
         # will return a list of Nodes and Index Nodes
         nodes = self.get_nodes_from_elements(
-            elements, node.metadata, ref_doc_text=node.get_content()
+            elements, node, ref_doc_text=node.get_content()
         )
         source_document = node.source_node or node.as_related_node_info()
         for n in nodes:
diff --git a/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py b/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py
index 43e684c19d..698e397550 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py
@@ -67,7 +67,7 @@ class UnstructuredElementNodeParser(BaseElementNodeParser):
         self.extract_table_summaries(table_elements)
         # convert into nodes
         # will return a list of Nodes and Index Nodes
-        nodes = self.get_nodes_from_elements(elements, node.metadata)
+        nodes = self.get_nodes_from_elements(elements, node)
 
         source_document = node.source_node or node.as_related_node_info()
         for n in nodes:
-- 
GitLab