From 2fd9a0295dc9c427da57f5f74a2a978797ac00cc Mon Sep 17 00:00:00 2001
From: Haotian Zhang <socool.king@gmail.com>
Date: Wed, 6 Mar 2024 16:49:52 -0500
Subject: [PATCH] Make other nodes parsing for element parser configurable
 (#11717)

---
 .../core/node_parser/relational/base_element.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
index f14809e149..17ea964ac8 100644
--- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py
+++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py
@@ -13,6 +13,7 @@ from llama_index.core.llms.llm import LLM
 from llama_index.core.node_parser.interface import NodeParser
 from llama_index.core.schema import BaseNode, Document, IndexNode, TextNode
 from llama_index.core.utils import get_tqdm_iterable
+from llama_index.core.node_parser import SentenceSplitter
 
 DEFAULT_SUMMARY_QUERY_STR = """\
 What is this table about? Give a very concise summary (imagine you are adding a new caption and summary for this table), \
@@ -78,14 +79,19 @@ class BaseElementNodeParser(NodeParser):
     )
     num_workers: int = Field(
         default=DEFAULT_NUM_WORKERS,
-        description="Num of works for async jobs.",
+        description="Num of workers for async jobs.",
     )
 
     show_progress: bool = Field(default=True, description="Whether to show progress.")
 
+    nested_node_parser: Optional[NodeParser] = Field(
+        default=None,
+        description="Other types of node parsers to handle some types of nodes.",
+    )
+
     @classmethod
     def class_name(cls) -> str:
-        return "BaseStructuredNodeParser"
+        return "BaseElementNodeParser"
 
     @classmethod
     def from_defaults(
@@ -255,9 +261,7 @@ class BaseElementNodeParser(NodeParser):
         metadata_inherited: Optional[Dict[str, Any]] = None,
     ) -> List[BaseNode]:
         """Get nodes and mappings."""
-        from llama_index.core.node_parser import SentenceSplitter
-
-        node_parser = SentenceSplitter()
+        node_parser = self.nested_node_parser or SentenceSplitter()
 
         nodes = []
         cur_text_el_buffer: List[str] = []
@@ -338,7 +342,8 @@ class BaseElementNodeParser(NodeParser):
                 nodes.extend([index_node, text_node])
             else:
                 cur_text_el_buffer.append(str(element.element))
-        # flush text buffer
+
+        # flush text buffer for the last batch
         if len(cur_text_el_buffer) > 0:
             cur_text_nodes = self._get_nodes_from_buffer(
                 cur_text_el_buffer, node_parser
-- 
GitLab