From 2ebdb23630347876e2915cffb338fc1d17b4439f Mon Sep 17 00:00:00 2001 From: Haotian Zhang <socool.king@gmail.com> Date: Wed, 28 Feb 2024 00:12:35 -0500 Subject: [PATCH] improve multi doc retrieval (#11346) * improve multi doc retrieval * cr * cr * cr * cr --- .../llama_index/core/node_parser/interface.py | 1 + .../core/node_parser/relational/base_element.py | 11 +++++++++-- .../core/node_parser/relational/markdown_element.py | 2 +- .../node_parser/relational/unstructured_element.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/llama-index-core/llama_index/core/node_parser/interface.py b/llama-index-core/llama_index/core/node_parser/interface.py index e7ca5b13aa..198410eca0 100644 --- a/llama-index-core/llama_index/core/node_parser/interface.py +++ b/llama-index-core/llama_index/core/node_parser/interface.py @@ -1,4 +1,5 @@ """Node parser interface.""" + from abc import ABC, abstractmethod from typing import Any, Callable, List, Sequence diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py index d5870793cb..f14809e149 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py +++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py @@ -249,7 +249,11 @@ class BaseElementNodeParser(NodeParser): doc = Document(text="\n\n".join(list(buffer))) return node_parser.get_nodes_from_documents([doc]) - def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]: + def get_nodes_from_elements( + self, + elements: List[Element], + metadata_inherited: Optional[Dict[str, Any]] = None, + ) -> List[BaseNode]: """Get nodes and mappings.""" from llama_index.core.node_parser import SentenceSplitter @@ -342,5 +346,8 @@ class BaseElementNodeParser(NodeParser): nodes.extend(cur_text_nodes) cur_text_el_buffer = [] - # remove empty nodes + # remove empty nodes and keep node original metadata inherited from parent nodes + for node in nodes: + if metadata_inherited: + node.metadata.update(metadata_inherited) return [node for node in nodes if len(node.text) > 0] diff --git a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py index 9b2320055c..4ee8fbedf1 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py +++ b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py @@ -57,7 +57,7 @@ class MarkdownElementNodeParser(BaseElementNodeParser): self.extract_table_summaries(table_elements) # convert into nodes # will return a list of Nodes and Index Nodes - return self.get_nodes_from_elements(elements) + return self.get_nodes_from_elements(elements, node.metadata) def extract_elements( self, diff --git a/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py b/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py index e19124addc..9b29737cbf 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py +++ b/llama-index-core/llama_index/core/node_parser/relational/unstructured_element.py @@ -92,7 +92,7 @@ class UnstructuredElementNodeParser(BaseElementNodeParser): self.extract_table_summaries(table_elements) # convert into nodes # will return a list of Nodes and Index Nodes - return self.get_nodes_from_elements(elements) + return self.get_nodes_from_elements(elements, node.metadata) def extract_elements( self, text: str, table_filters: Optional[List[Callable]] = None, **kwargs: Any -- GitLab