diff --git a/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/docs/base.py b/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/docs/base.py index 253b91b9df63d5cb2d84f83f5503cd3e38c6b7c1..affac80454f79f9bb83374f232fb0fa7da6564ad 100644 --- a/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/docs/base.py +++ b/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/docs/base.py @@ -2,10 +2,14 @@ import logging import os -from typing import Any, List +import random +import string +from typing import Any, List, Optional import googleapiclient.discovery as discovery from google_auth_oauthlib.flow import InstalledAppFlow + +from llama_index.core.bridge.pydantic import Field from llama_index.core.readers.base import BasePydanticReader from llama_index.core.schema import Document @@ -40,8 +44,14 @@ class GoogleDocsReader(BasePydanticReader): is_remote: bool = True - def __init__(self) -> None: - """Initialize with parameters.""" + split_on_heading_level: Optional[int] = Field( + default=None, + description="If set the document will be split on the specified heading level.", + ) + + include_toc: bool = Field( + default=True, description="Include table of contents elements." + ) @classmethod def class_name(cls) -> str: @@ -58,12 +68,9 @@ class GoogleDocsReader(BasePydanticReader): results = [] for document_id in document_ids: - doc = self._load_doc(document_id) - results.append( - Document( - text=doc, id_=document_id, metadata={"document_id": document_id} - ) - ) + docs = self._load_doc(document_id) + results.extend(docs) + return results def _load_doc(self, document_id: str) -> str: @@ -77,9 +84,12 @@ class GoogleDocsReader(BasePydanticReader): """ credentials = self._get_credentials() docs_service = discovery.build("docs", "v1", credentials=credentials) - doc = docs_service.documents().get(documentId=document_id).execute() - doc_content = doc.get("body").get("content") - return self._read_structural_elements(doc_content) + google_doc = docs_service.documents().get(documentId=document_id).execute() + google_doc_content = google_doc.get("body").get("content") + + doc_metadata = {"document_id": document_id} + + return self._structural_elements_to_docs(google_doc_content, doc_metadata) def _get_credentials(self) -> Any: """Get valid user credentials from storage. @@ -148,9 +158,95 @@ class GoogleDocsReader(BasePydanticReader): text += self._read_structural_elements(toc.get("content")) return text + def _determine_heading_level(self, element): + """Extracts the heading level, label, and ID from a document element. + + Args: + element: a Structural Element. + """ + level = None + heading_key = None + heading_id = None + if self.split_on_heading_level and "paragraph" in element: + style = element.get("paragraph").get("paragraphStyle") + style_type = style.get("namedStyleType", "") + heading_id = style.get("headingId", None) + if style_type == "TITLE": + level = 0 + heading_key = "title" + elif style_type.startswith("HEADING_"): + level = int(style_type.split("_")[1]) + if level > self.split_on_heading_level: + return None, None, None + + heading_key = f"Header {level}" + + return level, heading_key, heading_id + + def _generate_doc_id(self, metadata: dict): + if "heading_id" in metadata: + heading_id = metadata["heading_id"] + else: + heading_id = "".join( + random.choices(string.ascii_letters + string.digits, k=8) + ) + return f"{metadata['document_id']}_{heading_id}" + + def _structural_elements_to_docs( + self, elements: List[Any], doc_metadata: dict + ) -> Any: + """Recurse through a list of Structural Elements. + + Split documents on heading if split_on_heading_level is set. + + Args: + elements: a list of Structural Elements. + """ + docs = [] + + current_heading_level = self.split_on_heading_level + + metadata = doc_metadata.copy() + text = "" + for value in elements: + element_text = self._read_structural_elements([value]) + + level, heading_key, heading_id = self._determine_heading_level(value) + + if level is not None: + if level == self.split_on_heading_level: + if text.strip(): + docs.append( + Document( + id_=self._generate_doc_id(metadata), + text=text, + metadata=metadata.copy(), + ) + ) + text = "" + if "heading_id" in metadata: + metadata["heading_id"] = heading_id + elif level < current_heading_level: + metadata = doc_metadata.copy() + + metadata[heading_key] = element_text + current_heading_level = level + else: + text += element_text + + if text: + if docs: + id_ = self._generate_doc_id(metadata) + else: + id_ = metadata["document_id"] + docs.append(Document(id_=id_, text=text, metadata=metadata)) + + return docs + if __name__ == "__main__": - reader = GoogleDocsReader() - logger.info( - reader.load_data(document_ids=["11ctUj_tEf5S8vs_dk8_BNi-Zk8wW5YFhXkKqtmU_4B8"]) + reader = GoogleDocsReader(split_on_heading_level=1) + docs = reader.load_data( + document_ids=["1UORoHYBKmOdcv4g94znMF0ildBYWiu3C2M2MEsWN4mM"] ) + logger.info(docs)