diff --git a/llama-index-core/llama_index/core/node_parser/__init__.py b/llama-index-core/llama_index/core/node_parser/__init__.py index 13a3c11c4972b49ef42f12a8c32bff8a8fe23486..0f24dd07efae5586b771b1239b8280c1c04a3317 100644 --- a/llama-index-core/llama_index/core/node_parser/__init__.py +++ b/llama-index-core/llama_index/core/node_parser/__init__.py @@ -22,6 +22,9 @@ from llama_index.core.node_parser.relational.markdown_element import ( from llama_index.core.node_parser.relational.unstructured_element import ( UnstructuredElementNodeParser, ) +from llama_index.core.node_parser.relational.llama_parse_json_element import ( + LlamaParseJsonNodeParser, +) from llama_index.core.node_parser.text.code import CodeSplitter from llama_index.core.node_parser.text.langchain import LangchainNodeParser from llama_index.core.node_parser.text.semantic_splitter import ( @@ -57,6 +60,7 @@ __all__ = [ "get_root_nodes", "get_child_nodes", "get_deeper_nodes", + "LlamaParseJsonNodeParser", # deprecated, for backwards compatibility "SimpleNodeParser", ] diff --git a/llama-index-core/llama_index/core/node_parser/relational/__init__.py b/llama-index-core/llama_index/core/node_parser/relational/__init__.py index 405faa69e028cbbffe2bc65500e51833084c5a4c..4844d0fc0aa7073d93d4e4a6984367022c439b43 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/__init__.py +++ b/llama-index-core/llama_index/core/node_parser/relational/__init__.py @@ -7,9 +7,13 @@ from llama_index.core.node_parser.relational.markdown_element import ( from llama_index.core.node_parser.relational.unstructured_element import ( UnstructuredElementNodeParser, ) +from llama_index.core.node_parser.relational.llama_parse_json_element import ( + LlamaParseJsonNodeParser, +) __all__ = [ "HierarchicalNodeParser", "MarkdownElementNodeParser", "UnstructuredElementNodeParser", + "LlamaParseJsonNodeParser", ] diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py 
from typing import Any, Callable, List, Optional, Dict

from llama_index.core.node_parser.relational.base_element import (
    BaseElementNodeParser,
    Element,
)
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.node_parser.relational.utils import md_to_df

# NOTE: this parser relies on the optional ``markdown: Optional[str]`` and
# ``page_number: Optional[int]`` fields of ``Element`` (declared in
# ``base_element.py``).


class LlamaParseJsonNodeParser(BaseElementNodeParser):
    """Llama Parse Json format element node parser.

    Splits a json format document from LlamaParse into Text Nodes and Index Nodes
    corresponding to embedded objects (e.g. tables).

    """

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used to identify this parser."""
        return "LlamaParseJsonNodeParser"

    def get_nodes_from_node(self, node: TextNode) -> List[BaseNode]:
        """Get nodes from node.

        Extracts the elements of ``node``, summarises the table elements and
        converts everything into a list of Nodes and Index Nodes.
        """
        elements = self.extract_elements(
            node.get_content(),
            table_filters=[self.filter_table],
            node_id=node.id_,
            node_metadata=node.metadata,
        )
        table_elements = self.get_table_elements(elements)
        # extract summaries over table elements
        self.extract_table_summaries(table_elements)
        # convert into nodes
        # will return a list of Nodes and Index Nodes
        return self.get_nodes_from_elements(elements, node.metadata)

    def extract_elements(
        self,
        text: str,
        mode: Optional[str] = "json",
        node_id: Optional[str] = None,
        node_metadata: Optional[Dict[str, Any]] = None,
        table_filters: Optional[List[Callable]] = None,
        **kwargs: Any,
    ) -> List[Element]:
        """Extract elements from json based nodes.

        Args:
            text: node's text content. Only parsed (as markdown) when the
                selected mode cannot be served from ``node_metadata``.
            mode: ``"json"`` reads ``node_metadata["items"]``, ``"images"``
                reads ``node_metadata["images"]``; any other value, or a
                missing ``node_metadata``, falls back to parsing ``text``.
            node_id: unique id for the node; embedded in the element ids so
                that different nodes do not produce the same ids.
            node_metadata: metadata for the node. The json output for the
                nodes contains a lot of fields for elements; the keys read
                here are ``"page"``, ``"items"`` and ``"images"``.
            table_filters: predicates applied to each well-formed table
                element; a table is kept only if all of them return True.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            The extracted elements. After normalisation every element has
            type ``"table"``, ``"table_text"`` or ``"text"`` (headings, code
            and images are all turned into ``"text"``), and consecutive text
            elements are merged into one.
        """
        # ``node_metadata`` defaults to None, so it must be guarded before use.
        page_number = node_metadata.get("page") if node_metadata is not None else None

        if mode == "json" and node_metadata is not None:
            elements = self._parse_json_items(
                node_metadata.get("items") or [], page_number
            )
        elif mode == "images" and node_metadata is not None:
            # only get images from json metadata
            elements = [
                Element(
                    id=f"id_page_{page_number}_image_{idx}",
                    type="image",
                    element=image,
                )
                for idx, image in enumerate(node_metadata.get("images") or [])
            ]
        else:
            elements = self._parse_raw_text(text)

        for idx, element in enumerate(elements):
            scoped_id = f"id_page_{page_number}_{node_id}_{idx}"
            text_id = scoped_id if node_id else f"id_page_{page_number}_{idx}"

            if element.type != "table":
                # if the element is not a table, keep it as text
                elements[idx] = Element(
                    id=text_id, type="text", element=element.element
                )
                continue

            # Tables coming from the json items carry their markdown in
            # ``markdown``; tables found by the raw-text parser only have it
            # in ``element``.
            table_md = self._table_markdown(element)
            table_lines = table_md.split("\n") if table_md is not None else []

            # a table needs at least 2 rows
            should_keep = len(table_lines) >= 2
            # if the rows do not all have the same number of columns it is
            # not a perfect table: its raw text is stored instead of a dataframe
            perfect_table = len({len(line.split("|")) for line in table_lines}) <= 1

            # apply the table filters (currently only used to drop empty tables)
            if should_keep and perfect_table and table_filters is not None:
                should_keep = all(tf(element) for tf in table_filters)

            if not should_keep or table_md is None:
                elements[idx] = Element(
                    id=text_id, type="text", element=element.element
                )
            elif perfect_table:
                # NOTE(review): the whole source Element (not element.element)
                # is stored as the payload here; kept as-is because downstream
                # consumers are not visible from this file -- confirm intent.
                elements[idx] = Element(
                    id=scoped_id if node_id else f"id_{idx}",
                    type="table",
                    element=element,
                    table=md_to_df(table_md),
                )
            else:
                # non-perfect tables keep their raw text and get a distinct
                # type to differentiate them from perfect tables
                elements[idx] = Element(
                    id=scoped_id if node_id else f"id_{idx}",
                    type="table_text",
                    element=element.element,
                )

        # merge consecutive text elements together for now
        merged_elements: List[Element] = []
        for element in elements:
            previous = merged_elements[-1] if merged_elements else None
            if (
                previous is not None
                and element.type == "text"
                and previous.type == "text"
                # only string payloads can be concatenated; anything else
                # (e.g. a missing json "value") is kept as its own element
                and isinstance(previous.element, str)
                and isinstance(element.element, str)
            ):
                previous.element += "\n" + element.element
            else:
                merged_elements.append(element)
        return merged_elements

    def filter_table(self, table_element: Any) -> bool:
        """Filter tables.

        Returns True only when the table's markdown converts to a non-empty
        dataframe with more than one column.
        """
        table_md = self._table_markdown(table_element)
        if table_md is None:
            return False

        # convert markdown of the table to df
        table_df = md_to_df(table_md)

        # check if table_df is not None, has more than one row, and more than one column
        return table_df is not None and not table_df.empty and len(table_df.columns) > 1

    @staticmethod
    def _table_markdown(table_element: Any) -> Optional[str]:
        """Return the markdown text of a table element, or None if it has none."""
        markdown = getattr(table_element, "markdown", None)
        if markdown is not None:
            return markdown
        raw = getattr(table_element, "element", None)
        return raw if isinstance(raw, str) else None

    @staticmethod
    def _parse_json_items(
        json_items: List[Dict[str, Any]], page_number: Optional[int]
    ) -> List[Element]:
        """Build heading/text/table elements from the json ``items`` list."""
        elements: List[Element] = []
        for element_idx, json_item in enumerate(json_items):
            ele_type = json_item.get("type")
            if ele_type == "heading":
                elements.append(
                    Element(
                        id=f"id_page_{page_number}_heading_{element_idx}",
                        type="heading",
                        title_level=json_item.get("lvl"),
                        element=json_item.get("value"),
                        markdown=json_item.get("md"),
                        page_number=page_number,
                    )
                )
            elif ele_type == "text":
                elements.append(
                    Element(
                        id=f"id_page_{page_number}_text_{element_idx}",
                        type="text",
                        element=json_item.get("value"),
                        markdown=json_item.get("md"),
                        page_number=page_number,
                    )
                )
            elif ele_type == "table":
                elements.append(
                    Element(
                        id=f"id_page_{page_number}_table_{element_idx}",
                        type="table",
                        element=json_item.get("rows"),
                        markdown=json_item.get("md"),
                        page_number=page_number,
                    )
                )
            # items of any other type are ignored
        return elements

    @staticmethod
    def _parse_raw_text(text: str) -> List[Element]:
        """Split raw markdown text into code/table/title/text elements, line by line."""
        elements: List[Element] = []
        current: Optional[Element] = None

        for line in text.split("\n"):
            if line.startswith("```"):
                if current is not None and current.type == "code":
                    # this fence ends the current code block
                    elements.append(current)
                    current = None
                    # if there is some text after the ``` create a text element with it
                    if len(line) > 3:
                        elements.append(
                            Element(
                                id=f"id_{len(elements)}",
                                type="text",
                                element=line.lstrip("`"),
                            )
                        )
                elif line.count("```") == 2 and line[-3] != "`":
                    # inline code block (a second ``` in the line but not at the end)
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}",
                        type="code",
                        element=line.lstrip("`"),
                    )
                elif current is not None and current.type == "text":
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="text", element=line
                    )
            elif current is not None and current.type == "code":
                # inside a code block every line belongs to the block
                current.element += "\n" + line
            elif line.startswith("|"):
                if current is not None and current.type == "table":
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="table", element=line
                    )
            elif line.startswith("#"):
                if current is not None:
                    elements.append(current)
                stripped = line.lstrip("#")
                current = Element(
                    id=f"id_{len(elements)}",
                    type="title",
                    element=stripped,
                    title_level=len(line) - len(stripped),
                )
            elif current is not None and current.type == "text":
                current.element += "\n" + line
            else:
                if current is not None:
                    elements.append(current)
                current = Element(
                    id=f"id_{len(elements)}", type="text", element=line
                )

        if current is not None:
            elements.append(current)
        return elements
from io import StringIO
from typing import Optional

import pandas as pd


def md_to_df(md_str: str) -> Optional[pd.DataFrame]:
    """Convert a Markdown pipe table to a dataframe.

    Args:
        md_str: the markdown table. The first line is the header, the second
            line (the header separator) is discarded, and the remaining lines
            are the data rows. Every row is expected to start and end with a
            pipe once surrounding whitespace is removed.

    Returns:
        The parsed dataframe, or None when the table is empty.
    """
    # Strip each line first: the slicing below removes a fixed two characters
    # from both ends, so trailing spaces or a "\r" from CRLF input would
    # otherwise leave a dangling separator and create a phantom column.
    rows = [line.strip() for line in md_str.split("\n")]

    # Remove the second line (table header separator)
    rows = rows[:1] + rows[2:]

    # Replace " by "" and the markdown pipes by "," so every cell becomes a
    # quoted CSV field, then drop the first and last two chars of each line
    # (the leftovers of the outer pipes).
    csv_str = "\n".join(
        row.replace('"', '""').replace("|", '","')[2:-2] for row in rows
    )

    # Check if the table is empty
    if len(csv_str) == 0:
        return None

    # Use pandas to read the CSV string into a DataFrame
    return pd.read_csv(StringIO(csv_str))


def html_to_df(html_str: str) -> Optional[pd.DataFrame]:
    """Convert the first HTML table found in ``html_str`` to a dataframe.

    Args:
        html_str: HTML markup expected to contain a ``<table>``.

    Returns:
        A dataframe whose columns come from the first row and whose data come
        from the remaining rows, or None when there is no table, the table
        has no rows, or the rows do not all have the same number of cells.

    Raises:
        ImportError: if the ``lxml`` package is not installed.
    """
    try:
        from lxml import html
    except ImportError as err:
        raise ImportError(
            "You must install the `lxml` package to use this node parser."
        ) from err

    tree = html.fromstring(html_str)
    tables = tree.xpath("//table")
    # No table at all is treated like any other unusable table.
    if not tables:
        return None

    data = [
        [cell.text.strip() if cell.text is not None else "" for cell in row.xpath(".//td")]
        for row in tables[0].xpath(".//tr")
    ]

    # Check if the table is empty
    if not data:
        return None

    # Check if the all rows have the same number of columns
    if any(len(row) != len(data[0]) for row in data):
        return None

    return pd.DataFrame(data[1:], columns=data[0])