diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py index f14809e1496eaf7a6927145fb3b21db36d9bbf8f..17ea964ac8b73f483e0d1bcee24da7008725992e 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py +++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py @@ -13,6 +13,7 @@ from llama_index.core.llms.llm import LLM from llama_index.core.node_parser.interface import NodeParser from llama_index.core.schema import BaseNode, Document, IndexNode, TextNode from llama_index.core.utils import get_tqdm_iterable +from llama_index.core.node_parser import SentenceSplitter DEFAULT_SUMMARY_QUERY_STR = """\ What is this table about? Give a very concise summary (imagine you are adding a new caption and summary for this table), \ @@ -78,14 +79,19 @@ class BaseElementNodeParser(NodeParser): ) num_workers: int = Field( default=DEFAULT_NUM_WORKERS, - description="Num of works for async jobs.", + description="Num of workers for async jobs.", ) show_progress: bool = Field(default=True, description="Whether to show progress.") + nested_node_parser: Optional[NodeParser] = Field( + default=None, + description="Other types of node parsers to handle some types of nodes.", + ) + @classmethod def class_name(cls) -> str: - return "BaseStructuredNodeParser" + return "BaseElementNodeParser" @classmethod def from_defaults( @@ -255,9 +261,7 @@ class BaseElementNodeParser(NodeParser): metadata_inherited: Optional[Dict[str, Any]] = None, ) -> List[BaseNode]: """Get nodes and mappings.""" - from llama_index.core.node_parser import SentenceSplitter - - node_parser = SentenceSplitter() + node_parser = self.nested_node_parser or SentenceSplitter() nodes = [] cur_text_el_buffer: List[str] = [] @@ -338,7 +342,8 @@ class BaseElementNodeParser(NodeParser): nodes.extend([index_node, text_node]) else: cur_text_el_buffer.append(str(element.element)) - # flush text buffer + + # flush text buffer for the last batch if len(cur_text_el_buffer) > 0: cur_text_nodes = self._get_nodes_from_buffer( cur_text_el_buffer, node_parser