From 371ef2cfe7c784f403fa0fd65f0c8133acd9bfc6 Mon Sep 17 00:00:00 2001
From: Haotian Zhang <socool.king@gmail.com>
Date: Wed, 7 Feb 2024 13:10:38 -0800
Subject: [PATCH] Fix MD parser for inconsistent tables (#10488)

* Fix MD parser for inconsistent tables

* cf

* cr

* comment

* unit tests

* fix
---
 .../node_parser/relational/base_element.py    | 53 +++++++++++--------
 .../relational/markdown_element.py            | 26 ++++++---
 tests/node_parser/test_markdown_element.py    |  7 ++-
 3 files changed, 55 insertions(+), 31 deletions(-)

diff --git a/llama_index/node_parser/relational/base_element.py b/llama_index/node_parser/relational/base_element.py
index 106dd10b6..3dcbf1887 100644
--- a/llama_index/node_parser/relational/base_element.py
+++ b/llama_index/node_parser/relational/base_element.py
@@ -126,7 +126,7 @@ class BaseElementNodeParser(NodeParser):
 
     def get_table_elements(self, elements: List[Element]) -> List[Element]:
         """Get table elements."""
-        return [e for e in elements if e.type == "table"]
+        return [e for e in elements if e.type == "table" or e.type == "table_text"]
 
     def get_text_elements(self, elements: List[Element]) -> List[Element]:
         """Get text elements."""
@@ -146,7 +146,7 @@ class BaseElementNodeParser(NodeParser):
 
         table_context_list = []
         for idx, element in tqdm(enumerate(elements)):
-            if element.type != "table":
+            if element.type not in ("table", "table_text"):
                 continue
             table_context = str(element.element)
             if idx > 0 and str(elements[idx - 1].element).lower().strip().startswith(
@@ -249,8 +249,8 @@ class BaseElementNodeParser(NodeParser):
         nodes = []
         cur_text_el_buffer: List[str] = []
         for element in elements:
-            if element.type == "table":
-                # flush text buffer
+            if element.type == "table" or element.type == "table_text":
+                # flush text buffer for table
                 if len(cur_text_el_buffer) > 0:
                     cur_text_nodes = self._get_nodes_from_buffer(
                         cur_text_el_buffer, node_parser
@@ -259,7 +259,27 @@ class BaseElementNodeParser(NodeParser):
                     cur_text_el_buffer = []
 
                 table_output = cast(TableOutput, element.table_output)
-                table_df = cast(pd.DataFrame, element.table)
+                table_md = ""
+                if element.type == "table":
+                    table_df = cast(pd.DataFrame, element.table)
+                    # We serialize the table as markdown as it allows better accuracy.
+                    # We do not use the table_df.to_markdown() method as it generates
+                    # a table with a token-hungry format.
+                    table_md = "|"
+                    for col_name, col in table_df.items():
+                        table_md += f"{col_name}|"
+                    table_md += "\n|"
+                    for col_name, col in table_df.items():
+                        table_md += f"---|"
+                    table_md += "\n"
+                    for row in table_df.itertuples():
+                        table_md += "|"
+                        for col in row[1:]:
+                            table_md += f"{col}|"
+                        table_md += "\n"
+                elif element.type == "table_text":
+                    # if the table is not a perfect table, we still keep its original text
+                    table_md = str(element.element)
                 table_id = element.id + "_table"
                 table_ref_id = element.id + "_table_ref"
 
@@ -284,29 +304,16 @@ class BaseElementNodeParser(NodeParser):
                     index_id=table_id,
                 )
 
-                # We serialize the table as markdown as it allow better accuracy
-                # We do not use the table_df.to_markdown() method as it generate
-                # a table with a token hngry format.
-                table_md = "|"
-                for col_name, col in table_df.items():
-                    table_md += f"{col_name}|"
-                table_md += "\n|"
-                for col_name, col in table_df.items():
-                    table_md += f"---|"
-                table_md += "\n"
-                for row in table_df.itertuples():
-                    table_md += "|"
-                    for col in row[1:]:
-                        table_md += f"{col}|"
-                    table_md += "\n"
-
                 table_str = table_summary + "\n" + table_md
+
                 text_node = TextNode(
                     text=table_str,
                     id_=table_id,
                     metadata={
-                        # serialize the table as a dictionary string
-                        "table_df": str(table_df.to_dict()),
+                        # perfect table: dataframe serialized as a dictionary string; otherwise the raw table text
+                        "table_df": str(table_df.to_dict())
+                        if element.type == "table"
+                        else table_md,
                         # add table summary for retrieval purposes
                         "table_summary": table_summary,
                     },
diff --git a/llama_index/node_parser/relational/markdown_element.py b/llama_index/node_parser/relational/markdown_element.py
index 4e1768944..7abea2608 100644
--- a/llama_index/node_parser/relational/markdown_element.py
+++ b/llama_index/node_parser/relational/markdown_element.py
@@ -144,27 +144,41 @@ class MarkdownElementNodeParser(BaseElementNodeParser):
         for idx, element in enumerate(elements):
             if element.type == "table":
                 should_keep = True
+                perfect_table = True
 
                 # verify that the table (markdown) have the same number of columns on each rows
                 table_lines = element.element.split("\n")
                 table_columns = [len(line.split("|")) for line in table_lines]
                 if len(set(table_columns)) > 1:
-                    should_keep = False
+                    # if the rows have differing numbers of columns, it is not a perfect table;
+                    # we store the raw text of such tables instead of converting them to a dataframe
+                    perfect_table = False
 
                 # verify that the table (markdown) have at least 2 rows
                 if len(table_lines) < 2:
                     should_keep = False
 
                 # apply the table filter, now only filter empty tables
-                if should_keep and table_filters is not None:
+                if should_keep and perfect_table and table_filters is not None:
                     should_keep = all(tf(element) for tf in table_filters)
 
                 # if the element is a table, convert it to a dataframe
                 if should_keep:
-                    table = md_to_df(element.element)
-                    elements[idx] = Element(
-                        id=f"id_{idx}", type="table", element=element, table=table
-                    )
+                    if perfect_table:
+                        table = md_to_df(element.element)
+
+                        elements[idx] = Element(
+                            id=f"id_{idx}", type="table", element=element, table=table
+                        )
+                    else:
+                        # for non-perfect tables, we will store the raw text
+                        # and give it a different type to differentiate it from perfect tables
+                        elements[idx] = Element(
+                            id=f"id_{idx}",
+                            type="table_text",
+                            element=element,
+                            # table=table
+                        )
                 else:
                     elements[idx] = Element(
                         id=f"id_{idx}",
diff --git a/tests/node_parser/test_markdown_element.py b/tests/node_parser/test_markdown_element.py
index 96b593109..597968e5a 100644
--- a/tests/node_parser/test_markdown_element.py
+++ b/tests/node_parser/test_markdown_element.py
@@ -76,10 +76,13 @@ def test_md_table_extraction_broken_table() -> None:
     print(f"Number of nodes: {len(nodes)}")
     for i, node in enumerate(nodes, start=0):
         print(f"Node {i}: {node}, Type: {type(node)}")
-    assert len(nodes) == 3
+    assert len(nodes) == 6
     assert isinstance(nodes[0], TextNode)
     assert isinstance(nodes[1], IndexNode)
     assert isinstance(nodes[2], TextNode)
+    assert isinstance(nodes[3], TextNode)
+    assert isinstance(nodes[4], IndexNode)
+    assert isinstance(nodes[5], TextNode)
 
 
 def test_complex_md() -> None:
@@ -2645,4 +2648,4 @@ Llama 2 is a new technology that carries risks with use. Testing conducted to da
     node_parser = MarkdownElementNodeParser(llm=MockLLM())
 
     nodes = node_parser.get_nodes_from_documents([test_data])
-    assert len(nodes) == 208
+    assert len(nodes) == 224
-- 
GitLab