From 2200c13cd70b5c05aa5dd16b55e4901977659173 Mon Sep 17 00:00:00 2001
From: Javier Torres <javierandrestorresreyes@gmail.com>
Date: Fri, 5 Apr 2024 18:25:02 -0500
Subject: [PATCH] Fix MarkdownNodeParser ref_doc_id (#12615)

---
 .../node_parser/relational/markdown_element.py   |  8 ++++++--
 .../tests/node_parser/test_markdown_element.py   | 16 ++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
index 186e4b8993..a087f95f98 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/markdown_element.py
@@ -4,7 +4,7 @@ from llama_index.core.node_parser.relational.base_element import (
     BaseElementNodeParser,
     Element,
 )
-from llama_index.core.schema import BaseNode, TextNode
+from llama_index.core.schema import BaseNode, TextNode, NodeRelationship
 from llama_index.core.node_parser.relational.utils import md_to_df
 
 
@@ -32,7 +32,11 @@ class MarkdownElementNodeParser(BaseElementNodeParser):
         self.extract_table_summaries(table_elements)
         # convert into nodes
         # will return a list of Nodes and Index Nodes
-        return self.get_nodes_from_elements(elements, node.metadata)
+        nodes = self.get_nodes_from_elements(elements, node.metadata)
+        source_document = node.source_node or node.as_related_node_info()
+        for n in nodes:
+            n.relationships[NodeRelationship.SOURCE] = source_document
+        return nodes
 
     def extract_elements(
         self,
diff --git a/llama-index-core/tests/node_parser/test_markdown_element.py b/llama-index-core/tests/node_parser/test_markdown_element.py
index 4f4b2dddaf..a15861c0b7 100644
--- a/llama-index-core/tests/node_parser/test_markdown_element.py
+++ b/llama-index-core/tests/node_parser/test_markdown_element.py
@@ -2649,3 +2649,19 @@ Llama 2 is a new technology that carries risks with use. Testing conducted to da
 
     nodes = node_parser.get_nodes_from_documents([test_data])
     assert len(nodes) == 224
+
+
+def test_extract_ref_doc_id():
+    test_document = Document(
+        text="""
+# Introduction
+Hello world!
+""",
+    )
+
+    node_parser = MarkdownElementNodeParser(llm=MockLLM())
+
+    nodes = node_parser.get_nodes_from_documents([test_document])
+    assert len(nodes) == 1
+
+    assert nodes[0].ref_doc_id == test_document.doc_id
-- 
GitLab