From e6daed5fdb36602f5345c23ee7403e2e67b915c5 Mon Sep 17 00:00:00 2001
From: Javier Torres <javierandrestorresreyes@gmail.com>
Date: Fri, 17 May 2024 13:15:42 -0500
Subject: [PATCH] Implement BaseFilesystemReader in S3Reader and SharePoint
 Reader (#13408)

* interface+simpledirreader
* implement for s3
* implement list files for sharepoint
* sharepoint
* changes core
* baseresourcesreader
* test_file
* s3 changes
* version bump
* sharepoint changes
* error handling
* use mixins
* use mixins
---
 .../llama-index-readers-file/pyproject.toml   |   2 +-
 .../tests/test_file.py                        |  66 ++++-
 .../readers/microsoft_sharepoint/base.py      | 259 +++++++++++++++++-
 .../pyproject.toml                            |   4 +-
 .../llama_index/readers/s3/base.py            |  82 +++++-
 .../llama-index-readers-s3/pyproject.toml     |   6 +-
 .../tests/test_readers_s3.py                  |  66 ++++-
 7 files changed, 456 insertions(+), 29 deletions(-)

diff --git a/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml
index f3ed4a6481..d3d1f60c22 100644
--- a/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml
+++ b/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml
@@ -54,7 +54,7 @@ version = "0.1.22"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-llama-index-core = "^0.10.1"
+llama-index-core = "^0.10.37.post1"
 # pymupdf is AGPLv3-licensed, so it's optional
 pymupdf = {optional = true, version = "^1.23.21"}
 beautifulsoup4 = "^4.12.3"
diff --git a/llama-index-integrations/readers/llama-index-readers-file/tests/test_file.py b/llama-index-integrations/readers/llama-index-readers-file/tests/test_file.py
index 7a0c116fe7..8d6b500c50 100644
--- a/llama-index-integrations/readers/llama-index-readers-file/tests/test_file.py
+++ b/llama-index-integrations/readers/llama-index-readers-file/tests/test_file.py
@@ -2,10 +2,12 @@
 
 from multiprocessing import cpu_count
 from tempfile import TemporaryDirectory
-from typing import Any, Dict
+from typing import Any, Dict, List
+import hashlib
 
 import pytest
 from llama_index.core.readers.file.base import SimpleDirectoryReader
+from llama_index.core.schema import Document
 
 try:
     from llama_index.readers.file import PDFReader
@@ -472,3 +474,65 @@ def test_parallel_load() -> None:
     # check paths. Split handles path_part_X doc_ids from md and json files
     for doc in documents:
         assert str(doc.node_id).split("_part")[0] in doc_paths
+
+
+def _compare_document_lists(
+    documents1: List[Document], documents2: List[Document]
+) -> None:
+    assert len(documents1) == len(documents2)
+    hashes_1 = {doc.hash for doc in documents1}
+    hashes_2 = {doc.hash for doc in documents2}
+    assert hashes_1 == hashes_2
+
+
+@pytest.mark.skipif(PDFReader is None, reason="llama-index-readers-file not installed")
+def test_list_and_read_file_workflow() -> None:
+    with TemporaryDirectory() as tmp_dir:
+        with open(f"{tmp_dir}/test1.txt", "w") as f:
+            f.write("test1")
+        with open(f"{tmp_dir}/test2.txt", "w") as f:
+            f.write("test2")
+
+        reader = SimpleDirectoryReader(tmp_dir)
+        original_docs = reader.load_data()
+
+        files = reader.list_resources()
+        assert len(files) == 2
+
+        new_docs: List[Document] = []
+        for file in files:
+            file_info = reader.get_resource_info(file)
+            assert file_info is not None
+            assert len(file_info) == 4
+
+            new_docs.extend(reader.load_resource(file))
+
+        _compare_document_lists(original_docs, new_docs)
+
+        new_docs = reader.load_resources(files)
+        _compare_document_lists(original_docs, new_docs)
+
+
+@pytest.mark.skipif(PDFReader is None, reason="llama-index-readers-file not installed")
+def test_read_file_content() -> None:
+    with TemporaryDirectory() as tmp_dir:
+        with open(f"{tmp_dir}/test1.txt", "w") as f:
+            f.write("test1")
+        with open(f"{tmp_dir}/test2.txt", "w") as f:
+            f.write("test2")
+
+        files_checksum = {
+            f"{tmp_dir}/test1.txt": hashlib.md5(
+                open(f"{tmp_dir}/test1.txt", "rb").read()
+            ).hexdigest(),
+            f"{tmp_dir}/test2.txt": hashlib.md5(
+                open(f"{tmp_dir}/test2.txt", "rb").read()
+            ).hexdigest(),
+        }
+
+        reader = SimpleDirectoryReader(tmp_dir)
+
+        for file in files_checksum:
+            content = reader.read_file_content(file)
+            checksum = hashlib.md5(content).hexdigest()
+            assert checksum == files_checksum[file]
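The new tests above double as a usage recipe for the resource API that this patch's core bump (`^0.10.37.post1`) brings to `SimpleDirectoryReader`. A minimal sketch of the same workflow outside pytest, assuming a local directory (`"./data"` is a placeholder path):

```python
from llama_index.core import SimpleDirectoryReader

# Point the reader at any local directory; "./data" is a placeholder.
reader = SimpleDirectoryReader("./data")

# Enumerate resources without loading them, and inspect their metadata.
resources = reader.list_resources()
for resource in resources:
    print(resource, reader.get_resource_info(resource))

# Load one resource as Documents, or fetch its raw bytes.
docs = reader.load_resource(resources[0])
raw_bytes = reader.read_file_content(resources[0])
```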
diff --git a/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/llama_index/readers/microsoft_sharepoint/base.py b/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/llama_index/readers/microsoft_sharepoint/base.py
index a4de5a20c8..5d30f93ba3 100644
--- a/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/llama_index/readers/microsoft_sharepoint/base.py
+++ b/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/llama_index/readers/microsoft_sharepoint/base.py
@@ -2,21 +2,27 @@
 
 import logging
 import os
+from pathlib import Path
 import tempfile
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Optional
 
 import requests
 
-from llama_index.core.readers import SimpleDirectoryReader
-from llama_index.core.readers.base import BaseReader, BasePydanticReader
+from llama_index.core.readers import SimpleDirectoryReader, FileSystemReaderMixin
+from llama_index.core.readers.base import (
+    BaseReader,
+    BasePydanticReader,
+    ResourcesReaderMixin,
+)
 from llama_index.core.schema import Document
 from llama_index.core.bridge.pydantic import PrivateAttr, Field
 
 logger = logging.getLogger(__name__)
 
 
-class SharePointReader(BasePydanticReader):
-    """SharePoint reader.
+class SharePointReader(BasePydanticReader, ResourcesReaderMixin, FileSystemReaderMixin):
+    """
+    SharePoint reader.
 
     Reads folders from the SharePoint site from a folder under documents.
@@ -121,6 +127,9 @@ class SharePointReader(BasePydanticReader):
         Raises:
             Exception: If the specified SharePoint site is not found.
         """
+        if hasattr(self, "_site_id_with_host_name"):
+            return self._site_id_with_host_name
+
         site_information_endpoint = (
             f"https://graph.microsoft.com/v1.0/sites?search={sharepoint_site_name}"
         )
@@ -157,6 +166,9 @@ class SharePointReader(BasePydanticReader):
         Raises:
             ValueError: If there is an error in obtaining the drive ID.
         """
+        if hasattr(self, "_drive_id"):
+            return self._drive_id
+
         self._drive_id_endpoint = f"https://graph.microsoft.com/v1.0/sites/{self._site_id_with_host_name}/drives"
 
         response = requests.get(
@@ -259,6 +271,24 @@ class SharePointReader(BasePydanticReader):
             logger.error(response.json()["error"])
             raise ValueError(response.json()["error"])
 
+    def _get_file_content_by_url(self, item: Dict[str, Any]) -> bytes:
+        """
+        Retrieves the content of the file from the provided URL.
+
+        Args:
+            item (Dict[str, Any]): Dictionary containing file metadata.
+
+        Returns:
+            bytes: The content of the file.
+        """
+        file_download_url = item["@microsoft.graph.downloadUrl"]
+        response = requests.get(file_download_url)
+        if response.status_code != 200:
+            logger.error(response.json()["error"])
+            raise ValueError(response.json()["error_description"])
+
+        return response.content
+
     def _download_file_by_url(self, item: Dict[str, Any], download_dir: str) -> str:
         """
         Downloads the file from the provided URL.
@@ -271,17 +301,16 @@
             str: The path of the downloaded file in the temporary directory.
         """
         # Get the download URL for the file.
-        file_download_url = item["@microsoft.graph.downloadUrl"]
         file_name = item["name"]
 
-        response = requests.get(file_download_url)
+        content = self._get_file_content_by_url(item)
 
         # Create the directory if it does not exist and save the file.
         if not os.path.exists(download_dir):
            os.makedirs(download_dir)
         file_path = os.path.join(download_dir, file_name)
         with open(file_path, "wb") as f:
-            f.write(response.content)
+            f.write(content)
 
         return file_path
 
@@ -337,6 +366,13 @@
                     permissions_dict[ids_key].append(id)
                     permissions_dict[display_names_key].append(display_name)
 
+        # sort to get consistent results, if possible
+        for key in permissions_dict:
+            try:
+                permissions_dict[key] = sorted(permissions_dict[key])
+            except TypeError:
+                pass
+
         return permissions_dict
 
     def _extract_metadata_for_file(self, item: Dict[str, Any]) -> Dict[str, str]:
@@ -531,3 +567,212 @@
 
         except Exception as exp:
             logger.error("An error occurred while accessing SharePoint: %s", exp)
+
+    def _list_folder_contents(
+        self, folder_id: str, recursive: bool, current_path: str
+    ) -> List[Path]:
+        """
+        Helper method to fetch the contents of a folder.
+
+        Args:
+            folder_id (str): ID of the folder whose contents are to be listed.
+            recursive (bool): Whether to include subfolders recursively.
+            current_path (str): Path accumulated so far, used as the prefix for returned file paths.
+
+        Returns:
+            List[Path]: List of file paths.
+        """
+        folder_contents_endpoint = (
+            f"{self._drive_id_endpoint}/{self._drive_id}/items/{folder_id}/children"
+        )
+        response = requests.get(
+            url=folder_contents_endpoint,
+            headers=self._authorization_headers,
+        )
+        items = response.json().get("value", [])
+
+        file_paths = []
+        for item in items:
+            if "folder" in item and recursive:
+                # Recursive call for subfolder
+                subfolder_id = item["id"]
+                subfolder_paths = self._list_folder_contents(
+                    subfolder_id, recursive, os.path.join(current_path, item["name"])
+                )
+                file_paths.extend(subfolder_paths)
+            elif "file" in item:
+                # Append file path
+                file_path = Path(os.path.join(current_path, item["name"]))
+                file_paths.append(file_path)
+
+        return file_paths
+
+    def list_resources(
+        self,
+        sharepoint_site_name: Optional[str] = None,
+        sharepoint_folder_path: Optional[str] = None,
+        sharepoint_folder_id: Optional[str] = None,
+        recursive: bool = True,
+    ) -> List[Path]:
+        """
+        Lists the files in the specified folder in the SharePoint site.
+
+        Args:
+            sharepoint_site_name (Optional[str]): Name of the SharePoint site; defaults to the reader's attribute.
+            sharepoint_folder_path (Optional[str]): Path of the folder to list; defaults to the reader's attribute.
+            sharepoint_folder_id (Optional[str]): ID of the folder to list; defaults to the reader's attribute.
+            recursive (bool): Whether to include subfolders recursively.
+
+        Returns:
+            List[Path]: A list of paths of the files in the specified folder.
+
+        Raises:
+            Exception: If an error occurs while accessing the SharePoint site.
+        """
+        # If no arguments are provided to load_data, default to the object attributes
+        if not sharepoint_site_name:
+            sharepoint_site_name = self.sharepoint_site_name
+
+        if not sharepoint_folder_path:
+            sharepoint_folder_path = self.sharepoint_folder_path
+
+        if not sharepoint_folder_id:
+            sharepoint_folder_id = self.sharepoint_folder_id
+
+        # TODO: make both of these values optional — and just default to the client ID defaults
+        if not sharepoint_site_name:
+            raise ValueError("sharepoint_site_name must be provided.")
+
+        if not sharepoint_folder_path and not sharepoint_folder_id:
+            raise ValueError(
+                "sharepoint_folder_path or sharepoint_folder_id must be provided."
+            )
+
+        file_paths = []
+        try:
+            access_token = self._get_access_token()
+            self._site_id_with_host_name = self._get_site_id_with_host_name(
+                access_token, sharepoint_site_name
+            )
+            self._drive_id = self._get_drive_id()
+            if not sharepoint_folder_id:
+                sharepoint_folder_id = self._get_sharepoint_folder_id(
+                    sharepoint_folder_path
+                )
+
+            # Fetch folder contents; folder_path may be unset when only an ID was given
+            folder_contents = self._list_folder_contents(
+                sharepoint_folder_id,
+                recursive,
+                os.path.join(sharepoint_site_name, sharepoint_folder_path or ""),
+            )
+            file_paths.extend(folder_contents)
+            return file_paths
+
+        except Exception as exp:
+            logger.error("An error occurred while listing files in SharePoint: %s", exp)
+            raise
+
+    def _get_item_from_path(self, input_file: Path) -> Dict[str, Any]:
+        """
+        Retrieves the item details for a specified file in SharePoint.
+
+        Args:
+            input_file (Path): The path of the file in SharePoint.
+                Should include the SharePoint site name and the folder path,
+                e.g. "site_name/folder_path/file_name".
+
+        Returns:
+            Dict[str, Any]: Dictionary containing the item details.
+        """
+        # Get the file ID
+        # remove the site_name prefix
+        file_path = (
+            str(input_file).lstrip("/").replace(f"{self.sharepoint_site_name}/", "", 1)
+        )
+        endpoint = f"{self._drive_id_endpoint}/{self._drive_id}/root:/{file_path}"
+
+        response = requests.get(
+            url=endpoint,
+            headers=self._authorization_headers,
+        )
+
+        return response.json()
+
+    def get_resource_info(self, resource_id: str, **kwargs) -> Dict:
+        """
+        Retrieves metadata for a specified file in SharePoint without downloading it.
+
+        Args:
+            resource_id (str): The path of the file in SharePoint. The path should include
+                the SharePoint site name and the folder path, e.g. "site_name/folder_path/file_name".
+        """
+        try:
+            item = self._get_item_from_path(Path(resource_id))
+
+            info_dict = {
+                "file_path": resource_id,
+                "size": item.get("size"),
+                "created_at": item.get("createdDateTime"),
+                "modified_at": item.get("lastModifiedDateTime"),
+                "etag": item.get("eTag"),
+            }
+
+            if (
+                self.attach_permission_metadata
+            ):  # changes in access control should trigger a reingestion of the file
+                permissions = self._get_permissions_info(item)
+                info_dict.update(permissions)
+
+            return {
+                meta_key: meta_value
+                for meta_key, meta_value in info_dict.items()
+                if meta_value is not None
+            }
+
+        except Exception as exp:
+            logger.error(
+                "An error occurred while fetching file information from SharePoint: %s",
+                exp,
+            )
+            raise
+
+    def load_resource(self, resource_id: str, **kwargs) -> List[Document]:
+        try:
+            access_token = self._get_access_token()
+            self._site_id_with_host_name = self._get_site_id_with_host_name(
+                access_token, self.sharepoint_site_name
+            )
+            self._drive_id = self._get_drive_id()
+
+            path = Path(resource_id)
+
+            item = self._get_item_from_path(path)
+
+            input_file_dir = path.parent
+
+            with tempfile.TemporaryDirectory() as temp_dir:
+                metadata = self._download_file(item, temp_dir, input_file_dir)
+                return self._load_documents_with_metadata(
+                    metadata, temp_dir, recursive=False
+                )
+
+        except Exception as exp:
+            logger.error(
+                "An error occurred while reading file from SharePoint: %s", exp
+            )
+            raise
+
+    def read_file_content(self, input_file: Path, **kwargs) -> bytes:
+        try:
+            access_token = self._get_access_token()
+            self._site_id_with_host_name = self._get_site_id_with_host_name(
+                access_token, self.sharepoint_site_name
+            )
+            self._drive_id = self._get_drive_id()
+
+            item = self._get_item_from_path(input_file)
+            return self._get_file_content_by_url(item)
+
+        except Exception as exp:
+            logger.error(
+                "An error occurred while reading file content from SharePoint: %s", exp
+            )
+            raise
diff --git a/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/pyproject.toml
index 1ed6491498..f872bc5813 100644
--- a/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/pyproject.toml
+++ b/llama-index-integrations/readers/llama-index-readers-microsoft-sharepoint/pyproject.toml
@@ -29,11 +29,11 @@ license = "MIT"
 maintainers = ["arun-soliton"]
 name = "llama-index-readers-microsoft-sharepoint"
 readme = "README.md"
-version = "0.2.2"
+version = "0.2.3"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-llama-index-core = "^0.10.1"
+llama-index-core = "^0.10.37.post1"
 requests = "^2.31.0"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py b/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
index 67586e1658..ed381fbc05 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
+++ b/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
@@ -7,14 +7,20 @@ A loader that fetches a file or iterates through a directory on AWS S3.
 
 import warnings
 from typing import Callable, Dict, List, Optional, Union
-
-from llama_index.core.readers import SimpleDirectoryReader
-from llama_index.core.readers.base import BaseReader, BasePydanticReader
+from datetime import datetime, timezone
+from pathlib import Path
+
+from llama_index.core.readers import SimpleDirectoryReader, FileSystemReaderMixin
+from llama_index.core.readers.base import (
+    BaseReader,
+    BasePydanticReader,
+    ResourcesReaderMixin,
+)
 from llama_index.core.schema import Document
 from llama_index.core.bridge.pydantic import Field
 
 
-class S3Reader(BasePydanticReader):
+class S3Reader(BasePydanticReader, ResourcesReaderMixin, FileSystemReaderMixin):
     """
     General reader for any S3 file or directory.
 
@@ -66,17 +72,20 @@ def class_name(cls) -> str:
         return "S3Reader"
 
-    def load_s3_files_as_docs(self, temp_dir=None) -> List[Document]:
-        """Load file(s) from S3."""
+    def _get_s3fs(self):
         from s3fs import S3FileSystem
 
-        s3fs = S3FileSystem(
+        return S3FileSystem(
             key=self.aws_access_id,
             endpoint_url=self.s3_endpoint_url,
             secret=self.aws_access_secret,
             token=self.aws_session_token,
         )
 
+    def _get_simple_directory_reader(self) -> SimpleDirectoryReader:
+        # we don't want to keep the reader as a field in the class to keep it serializable
+        s3fs = self._get_s3fs()
+
         input_dir = self.bucket
         input_files = None
 
@@ -85,7 +94,7 @@
         elif self.prefix:
             input_dir = f"{input_dir}/{self.prefix}"
 
-        loader = SimpleDirectoryReader(
+        return SimpleDirectoryReader(
             input_dir=input_dir,
             input_files=input_files,
             file_extractor=self.file_extractor,
@@ -97,8 +106,19 @@
             fs=s3fs,
         )
 
+    def load_s3_files_as_docs(self, temp_dir=None) -> List[Document]:
+        """Load file(s) from S3."""
+        loader = self._get_simple_directory_reader()
         return loader.load_data()
 
+    def _adjust_documents(self, documents: List[Document]) -> List[Document]:
+        for doc in documents:
+            if self.s3_endpoint_url:
+                doc.id_ = self.s3_endpoint_url + "_" + doc.id_
+            else:
+                doc.id_ = "s3_" + doc.id_
+        return documents
+
     def load_data(self, custom_temp_subdir: str = None) -> List[Document]:
         """
         Load the file(s) from S3.
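The S3 implementation mostly delegates to `SimpleDirectoryReader` over s3fs, except `get_resource_info`, which queries s3fs directly for extra fields (see the continuation below). A minimal usage sketch; the bucket name is a placeholder, and with no explicit credentials s3fs falls back to its default AWS credential chain:

```python
from llama_index.readers.s3 import S3Reader

reader = S3Reader(bucket="my-bucket")  # placeholder bucket name

# List object keys, inspect one via s3fs metadata, then load it.
keys = reader.list_resources()
info = reader.get_resource_info(keys[0])
# info contains file_path, file_size, last_modified_date, and content_hash (the ETag)

docs = reader.load_resource(keys[0])    # Documents, with ids prefixed by endpoint or "s3_"
raw = reader.read_file_content(keys[0]) # raw bytes
```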
@@ -116,10 +136,44 @@
         )
 
         documents = self.load_s3_files_as_docs()
-        for doc in documents:
-            if self.s3_endpoint_url:
-                doc.id_ = self.s3_endpoint_url + "_" + doc.id_
-            else:
-                doc.id_ = "s3_" + doc.id_
+        return self._adjust_documents(documents)
 
-        return documents
+    def list_resources(self, **kwargs) -> List[str]:
+        simple_directory_reader = self._get_simple_directory_reader()
+        return simple_directory_reader.list_resources(**kwargs)
+
+    def get_resource_info(self, resource_id: str, **kwargs) -> Dict:
+        # can't use SimpleDirectoryReader.get_resource_info because it lacks some fields
+        fs = self._get_s3fs()
+        info_result = fs.info(resource_id)
+
+        last_modified_date = info_result.get("LastModified")
+        if last_modified_date and isinstance(last_modified_date, datetime):
+            last_modified_date = last_modified_date.astimezone(timezone.utc).strftime(
+                "%Y-%m-%dT%H:%M:%SZ"
+            )
+        else:
+            last_modified_date = None
+
+        info_dict = {
+            "file_path": str(resource_id),
+            "file_size": info_result.get("size"),
+            "last_modified_date": last_modified_date,
+            "content_hash": info_result.get("ETag"),
+        }
+
+        # Ignore None values
+        return {
+            meta_key: meta_value
+            for meta_key, meta_value in info_dict.items()
+            if meta_value is not None
+        }
+
+    def load_resource(self, resource_id: str, **kwargs) -> List[Document]:
+        simple_directory_reader = self._get_simple_directory_reader()
+        docs = simple_directory_reader.load_resource(resource_id, **kwargs)
+        return self._adjust_documents(docs)
+
+    def read_file_content(self, input_file: Path, **kwargs) -> bytes:
+        simple_directory_reader = self._get_simple_directory_reader()
+        return simple_directory_reader.read_file_content(input_file, **kwargs)
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml
index 999795ba71..46b1499abb 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml
+++ b/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml
@@ -29,13 +29,13 @@ license = "MIT"
 maintainers = ["thejessezhang"]
 name = "llama-index-readers-s3"
 readme = "README.md"
-version = "0.1.7"
+version = "0.1.8"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-llama-index-core = "^0.10.1"
+llama-index-core = "^0.10.37.post1"
 llama-index-readers-file = "^0.1.12"
-s3fs = "^2024.3.1"
+s3fs = ">=2024.3.1"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py b/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py
index 27f887eb0e..11c4c41fbd 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py
+++ b/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py
@@ -1,10 +1,13 @@
 from llama_index.core.readers.base import BasePydanticReader
+from llama_index.core.schema import Document
 from llama_index.readers.s3 import S3Reader
+from typing import List
 
 from moto.server import ThreadedMotoServer
 import pytest
 import os
 import requests
 from s3fs import S3FileSystem
+import hashlib
 
 test_bucket = "test"
 files = [
@@ -40,7 +43,8 @@ def init_s3_files(s3_base):
     s3fs.mkdir(f"{test_bucket}/subdir")
     s3fs.mkdir(f"{test_bucket}/subdir2")
     for file in files:
-        s3fs.touch(file)
+        with s3fs.open(file, "w") as f:
+            f.write(f"test file: {file}")
 
 
 def test_class():
@@ -53,6 +57,8 @@ def test_load_all_files(init_s3_files):
         bucket=test_bucket,
         s3_endpoint_url=endpoint_url,
     )
+    files = reader.list_resources()
+    assert len(files) == 3
     documents = reader.load_data()
     assert len(documents) == len(files)
 
@@ -63,6 +69,8 @@ def test_load_single_file(init_s3_files):
         key="test.txt",
         s3_endpoint_url=endpoint_url,
     )
+    files = reader.list_resources()
+    assert len(files) == 1
     documents = reader.load_data()
     assert len(documents) == 1
     assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt"
@@ -74,11 +82,17 @@ def test_load_with_prefix(init_s3_files):
         prefix="subdir",
         s3_endpoint_url=endpoint_url,
     )
+    files = reader.list_resources()
+    assert len(files) == 1
+    assert str(files[0]).startswith(f"{test_bucket}/subdir")
     documents = reader.load_data()
     assert len(documents) == 1
     assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir/test2.txt"
 
     reader.prefix = "subdir2"
+    files = reader.list_resources()
+    assert len(files) == 1
+    assert str(files[0]).startswith(f"{test_bucket}/subdir2")
     documents = reader.load_data()
     assert len(documents) == 1
     assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir2/test3.txt"
@@ -95,6 +109,56 @@ def test_load_not_recursive(init_s3_files):
     assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt"
 
 
+def _compare_document_lists(
+    documents1: List[Document], documents2: List[Document]
+) -> None:
+    assert len(documents1) == len(documents2)
+    hashes_1 = {doc.hash for doc in documents1}
+    hashes_2 = {doc.hash for doc in documents2}
+    assert hashes_1 == hashes_2
+
+
+def test_list_and_read_file_workflow(init_s3_files):
+    reader = S3Reader(
+        bucket=test_bucket,
+        s3_endpoint_url=endpoint_url,
+    )
+
+    original_docs = reader.load_data()
+    files = reader.list_resources()
+    new_docs: List[Document] = []
+    for file in files:
+        file_info = reader.get_resource_info(file)
+        assert file_info is not None
+        assert len(file_info) == 4
+        new_docs.extend(reader.load_resource(file))
+    _compare_document_lists(original_docs, new_docs)
+
+    new_docs = reader.load_resources(files)
+    _compare_document_lists(original_docs, new_docs)
+
+
+def test_read_file_content(init_s3_files):
+    s3fs = S3FileSystem(
+        endpoint_url=endpoint_url,
+    )
+    checksums = {}
+    for file in files:
+        with s3fs.open(file, "rb") as f:
+            content = f.read()
+            checksums[file] = hashlib.md5(content).hexdigest()
+
+    reader = S3Reader(
+        bucket=test_bucket,
+        s3_endpoint_url=endpoint_url,
+    )
+
+    for file in files:
+        content = reader.read_file_content(file)
+        checksum = hashlib.md5(content).hexdigest()
+        assert checksum == checksums[file]
+
+
 def test_serialize():
     reader = S3Reader(
         bucket=test_bucket,
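The `s3_base` fixture and `endpoint_url` used by these tests sit outside the hunks shown here. For context, a moto-backed fixture typically looks roughly like the sketch below; the address and port are assumptions, not taken from this diff:

```python
import pytest
from moto.server import ThreadedMotoServer

# Assumed to match the endpoint the tests point S3Reader at.
endpoint_url = "http://127.0.0.1:5555"


@pytest.fixture(scope="module")
def s3_base():
    # Spin up a local S3-compatible server so no real AWS credentials are needed.
    server = ThreadedMotoServer(ip_address="127.0.0.1", port=5555)
    server.start()
    yield
    server.stop()
```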
-- 
GitLab