From 2ebdb23630347876e2915cffb338fc1d17b4439f Mon Sep 17 00:00:00 2001
From: Haotian Zhang <socool.king@gmail.com>
Date: Wed, 28 Feb 2024 00:12:35 -0500
Subject: [PATCH] improve multi doc retrieval (#11346)

* improve multi doc retrieval

* cr

* cr

* cr

* cr
---
 .../llama_index/core/node_parser/interface.py         |  1 +
 .../core/node_parser/relational/base_element.py       | 11 +++++++++--
 .../core/node_parser/relational/markdown_element.py   |  2 +-
 .../node_parser/relational/unstructured_element.py    |  2 +-
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/llama-index-core/llama_index/core/node_parser/interface.py b/llama-index-core/llama_index/core/node_parser/interface.py
index e7ca5b13aa..198410eca0 100644
--- a/llama-index-core/llama_index/core/node_parser/interface.py
+++ b/llama-index-core/llama_index/core/node_parser/interface.py
@@ -1,4 +1,5 @@
 """Node parser interface."""
+
 from abc import ABC, abstractmethod
 from typing import Any, Callable, List, Sequence
 
diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
index d5870793cb..f14809e149 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
@@ -249,7 +249,11 @@ class BaseElementNodeParser(NodeParser):
         doc = Document(text="\n\n".join(list(buffer)))
         return node_parser.get_nodes_from_documents([doc])
 
-    def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]:
+    def get_nodes_from_elements(
+        self,
+        elements: List[Element],
+        metadata_inherited: Optional[Dict[str, Any]] = None,
+    ) -> List[BaseNode]:
         """Get nodes and mappings."""
         from llama_index.core.node_parser import SentenceSplitter
 
@@ -342,5 +346,8 @@ class BaseElementNodeParser(NodeParser):
             nodes.extend(cur_text_nodes)
             cur_text_el_buffer = []
 
-        # remove empty nodes
+        # remove empty nodes and keep node original metadata inherited from parent nodes
+        for node in nodes:
+            if metadata_inherited:
+                node.metadata.update(metadata_inherited)
         return [node for node in nodes if len(node.text) > 0]
diff --git a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
index 9b2320055c..4ee8fbedf1 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
@@ -57,7 +57,7 @@ class MarkdownElementNodeParser(BaseElementNodeParser):
         self.extract_table_summaries(table_elements)
         # convert into nodes
         # will return a list of Nodes and Index Nodes
-        return self.get_nodes_from_elements(elements)
+        return self.get_nodes_from_elements(elements, node.metadata)
 
     def extract_elements(
         self,
diff --git a/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py b/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py
index e19124addc..9b29737cbf 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py
@@ -92,7 +92,7 @@ class UnstructuredElementNodeParser(BaseElementNodeParser):
         self.extract_table_summaries(table_elements)
         # convert into nodes
         # will return a list of Nodes and Index Nodes
-        return self.get_nodes_from_elements(elements)
+        return self.get_nodes_from_elements(elements, node.metadata)
 
     def extract_elements(
         self, text: str, table_filters: Optional[List[Callable]] = None, **kwargs: Any
-- 
GitLab