diff --git a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py index 186e4b899388fcc8ed58df42f1f7311aa07a4849..a087f95f982e7b1f6fa821a60547db450ef013a0 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py +++ b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py @@ -4,7 +4,7 @@ from llama_index.core.node_parser.relational.base_element import ( BaseElementNodeParser, Element, ) -from llama_index.core.schema import BaseNode, TextNode +from llama_index.core.schema import BaseNode, TextNode, NodeRelationship from llama_index.core.node_parser.relational.utils import md_to_df @@ -32,7 +32,11 @@ class MarkdownElementNodeParser(BaseElementNodeParser): self.extract_table_summaries(table_elements) # convert into nodes # will return a list of Nodes and Index Nodes - return self.get_nodes_from_elements(elements, node.metadata) + nodes = self.get_nodes_from_elements(elements, node.metadata) + source_document = node.source_node or node.as_related_node_info() + for n in nodes: + n.relationships[NodeRelationship.SOURCE] = source_document + return nodes def extract_elements( self, diff --git a/llama-index-core/tests/node_parser/test_markdown_element.py b/llama-index-core/tests/node_parser/test_markdown_element.py index 4f4b2dddaf6a1851dad2184c175849848f5496da..a15861c0b781419c620e737522c83500d3b0a40c 100644 --- a/llama-index-core/tests/node_parser/test_markdown_element.py +++ b/llama-index-core/tests/node_parser/test_markdown_element.py @@ -2649,3 +2649,19 @@ Llama 2 is a new technology that carries risks with use. Testing conducted to da nodes = node_parser.get_nodes_from_documents([test_data]) assert len(nodes) == 224 + + +def test_extract_ref_doc_id(): + test_document = Document( + text=""" +# Introduction +Hello world! +""", + ) + + node_parser = MarkdownElementNodeParser(llm=MockLLM()) + + nodes = node_parser.get_nodes_from_documents([test_document]) + assert len(nodes) == 1 + + assert nodes[0].ref_doc_id == test_document.doc_id