"""Title-aware semantic splitting of pre-parsed document elements.

Recovered from the patch that added
``semantic_router/splitters/unstructured_splitter.py``.

Elements are plain dicts with (at least) the keys read below: ``"type"``,
``"text"`` and ``"metadata"``.  The type names used here (``"Title"``,
``"Table"``) and the ``"text_as_html"`` metadata key presumably follow the
``unstructured`` library's element schema -- verify against the caller.
"""

import logging
import re
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only needed for annotations.  Kept out of the runtime import path: this
    # module is itself part of ``semantic_router.splitters``, so importing that
    # package while this module is loading could become circular.
    from semantic_router.encoders import BaseEncoder
    from semantic_router.splitters import RollingWindowSplitter

logger = logging.getLogger(__name__)

# Section name used for content that appears before any accepted title, and
# for ``Title`` elements that carry no ``"text"`` key.
_DEFAULT_TITLE = "Untitled"

# Compiled once; ``is_valid_title`` runs for every ``Title`` element.
_STARTS_LOWERCASE = re.compile(r"^[a-z]")
_DISALLOWED_TITLE_CHAR = re.compile(r"[^\w\s:\-\.]")


class UnstructuredSemanticSplitter:
    """Group document elements by title and split each section's text.

    Text elements of a section are handed to a ``RollingWindowSplitter``;
    ``Table`` elements are emitted as standalone chunks.
    """

    def __init__(
        self,
        encoder: "BaseEncoder",
        window_size: int,
        min_split_tokens: int,
        max_split_tokens: int,
    ):
        """Build the underlying ``RollingWindowSplitter``.

        Args:
            encoder: Encoder forwarded to ``RollingWindowSplitter``.
            window_size: Forwarded to ``RollingWindowSplitter``.
            min_split_tokens: Forwarded to ``RollingWindowSplitter``.
            max_split_tokens: Forwarded to ``RollingWindowSplitter``.
        """
        # Imported here rather than at module top so that importing this
        # module does not require ``semantic_router.splitters`` to be fully
        # initialised (see the TYPE_CHECKING note above).
        from semantic_router.splitters import RollingWindowSplitter

        self.splitter = RollingWindowSplitter(
            encoder=encoder,
            window_size=window_size,
            min_split_tokens=min_split_tokens,
            max_split_tokens=max_split_tokens,
        )

    def is_valid_title(self, title: str) -> bool:
        """Return whether *title* looks like a real section heading.

        A title is rejected when it is empty/blank, starts with an ASCII
        lowercase letter, contains a character other than word characters,
        whitespace, ``:``, ``-`` or ``.``, or ends with a dot.
        """
        # Empty, blank or missing text can never name a section.
        if not title or not title.strip():
            return False
        # Rule 1: title starts with a lowercase letter.
        if _STARTS_LOWERCASE.match(title):
            return False
        # Rule 2: title has a special character (excluding ":", "-" and ".").
        if _DISALLOWED_TITLE_CHAR.search(title):
            return False
        # Rule 3: title ends with a dot (reads like a sentence, not a heading).
        return not title.endswith(".")

    def _group_elements_by_title(
        self, elements: list[dict[str, Any]]
    ) -> dict[str, list[dict[str, Any]]]:
        """Bucket non-title elements under the most recent valid title.

        Elements seen before any valid title go under ``"Untitled"``.
        ``Title`` elements themselves are never stored: a valid one switches
        the current section, an invalid one is skipped.  Sections that repeat
        a title share one bucket.

        Returns:
            Mapping of title to its elements, in first-seen title order.
            Every bucket contains at least one element.
        """
        grouped: dict[str, list[dict[str, Any]]] = {}
        current_title = _DEFAULT_TITLE

        for element in elements:
            if element.get("type") == "Title":
                candidate = element.get("text", _DEFAULT_TITLE)
                if self.is_valid_title(candidate):
                    logger.debug("Accepted section title: %r", candidate)
                    current_title = candidate
                else:
                    logger.debug("Rejected section title: %r", candidate)
                continue
            # setdefault + append in one step: the element that opens a new
            # bucket must be stored too (it used to be dropped).
            grouped.setdefault(current_title, []).append(element)
        return grouped

    async def split_grouped_elements(
        self, elements: list[dict[str, Any]], splitter: "RollingWindowSplitter"
    ) -> list[dict[str, Any]]:
        """Split *elements* into titled chunks.

        Args:
            elements: Parsed document elements (dicts).
            splitter: Callable taking a list of strings and returning objects
                exposing a ``content`` attribute.

        Returns:
            A list of dicts with keys ``"title"``, ``"content"``,
            ``"chunk_index"`` and ``"metadata"``.  ``"chunk_index"`` is the
            position of the *section* among the grouped titles, so all chunks
            of one section share the same value.  Text chunks carry the
            metadata of the section's first element; table chunks carry the
            table element's own metadata.
        """
        grouped_elements = self._group_elements_by_title(elements)
        chunks_with_title: list[dict[str, Any]] = []

        def _append_chunk(*, title: str, content: str, index: int, metadata: dict):
            chunks_with_title.append(
                {
                    "title": title,
                    "content": content,
                    "chunk_index": index,
                    "metadata": metadata,
                }
            )

        def _flush_texts(texts: list[str], *, title: str, index: int, metadata: dict):
            # Run the splitter over the accumulated texts, one chunk per split.
            for split in splitter(texts):
                _append_chunk(
                    title=title,
                    content=split.content,
                    index=index,
                    metadata=metadata,
                )

        for index, (title, section_elements) in enumerate(grouped_elements.items()):
            # Buckets are never empty, so the first element always exists.
            section_metadata = section_elements[0].get("metadata", {})
            pending_texts: list[str] = []

            for element in section_elements:
                text = element.get("text")
                if not text:
                    continue

                if element.get("type") != "Table":
                    pending_texts.append(text)
                    continue

                # Emit the text that precedes the table before the table
                # itself so chunks stay in document order.
                # TODO: reset after PageBreak also
                if pending_texts:
                    _flush_texts(
                        pending_texts,
                        title=title,
                        index=index,
                        metadata=section_metadata,
                    )
                    pending_texts = []

                # The table is its own chunk; prefer the HTML rendering and
                # fall back to the table's plain text when it is absent.
                table_metadata = element.get("metadata", {})
                _append_chunk(
                    title=title,
                    content=table_metadata.get("text_as_html", text),
                    index=index,
                    metadata=table_metadata,
                )

            # Remaining text after the last table, or the whole section when
            # it contained no table.
            if pending_texts:
                _flush_texts(
                    pending_texts,
                    title=title,
                    index=index,
                    metadata=section_metadata,
                )
        return chunks_with_title

    async def __call__(self, elements: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Split *elements* using the splitter configured in ``__init__``."""
        return await self.split_grouped_elements(elements, self.splitter)