From 2fd9a0295dc9c427da57f5f74a2a978797ac00cc Mon Sep 17 00:00:00 2001 From: Haotian Zhang <socool.king@gmail.com> Date: Wed, 6 Mar 2024 16:49:52 -0500 Subject: [PATCH] Make other nodes parsing for element parser configurable (#11717) --- .../core/node_parser/relational/base_element.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/llama-index-core/llama_index/core/node_parser/relational/base_element.py b/llama-index-core/llama_index/core/node_parser/relational/base_element.py index f14809e149..17ea964ac8 100644 --- a/llama-index-core/llama_index/core/node_parser/relational/base_element.py +++ b/llama-index-core/llama_index/core/node_parser/relational/base_element.py @@ -13,6 +13,7 @@ from llama_index.core.llms.llm import LLM from llama_index.core.node_parser.interface import NodeParser from llama_index.core.schema import BaseNode, Document, IndexNode, TextNode from llama_index.core.utils import get_tqdm_iterable +from llama_index.core.node_parser import SentenceSplitter DEFAULT_SUMMARY_QUERY_STR = """\ What is this table about? Give a very concise summary (imagine you are adding a new caption and summary for this table), \ @@ -78,14 +79,19 @@ class BaseElementNodeParser(NodeParser): ) num_workers: int = Field( default=DEFAULT_NUM_WORKERS, - description="Num of works for async jobs.", + description="Num of workers for async jobs.", ) show_progress: bool = Field(default=True, description="Whether to show progress.") + nested_node_parser: Optional[NodeParser] = Field( + default=None, + description="Other types of node parsers to handle some types of nodes.", + ) + @classmethod def class_name(cls) -> str: - return "BaseStructuredNodeParser" + return "BaseElementNodeParser" @classmethod def from_defaults( @@ -255,9 +261,7 @@ class BaseElementNodeParser(NodeParser): metadata_inherited: Optional[Dict[str, Any]] = None, ) -> List[BaseNode]: """Get nodes and mappings.""" - from llama_index.core.node_parser import SentenceSplitter - - node_parser = SentenceSplitter() + node_parser = self.nested_node_parser or SentenceSplitter() nodes = [] cur_text_el_buffer: List[str] = [] @@ -338,7 +342,8 @@ class BaseElementNodeParser(NodeParser): nodes.extend([index_node, text_node]) else: cur_text_el_buffer.append(str(element.element)) - # flush text buffer + + # flush text buffer for the last batch if len(cur_text_el_buffer) > 0: cur_text_nodes = self._get_nodes_from_buffer( cur_text_el_buffer, node_parser -- GitLab