From 181b32291744542bc8c55acf2edea73b4814b69d Mon Sep 17 00:00:00 2001
From: Javier Torres <javierandrestorresreyes@gmail.com>
Date: Wed, 20 Mar 2024 16:47:34 -0500
Subject: [PATCH] Use S3FS in S3Reader (#12061)

* use s3fs

* when key=none

* use s3fs

* changelog + version bump

* tests + serialization

* changelog

* pants update

* module mappings

* add files reader to BUILD

* more BUILD

* fix schema

* give BaseReader a __modify_schema__

* dont use ModelField

* remove validators

* Support schema in BaseReader for PydanticV2

---------
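
Since the reader now inherits from `BasePydanticReader`, its configuration
can be serialized and restored. A minimal sketch mirroring the round-trip in
tests/test_readers_s3.py (bucket name and key below are placeholders):

    from llama_index.readers.s3 import S3Reader

    reader = S3Reader(bucket="my-bucket", key="docs/report.txt")

    # Schema support works under both pydantic v1 and v2.
    assert "bucket" in reader.schema()["properties"]

    # Round-trip the reader through JSON; callable fields such as
    # `file_extractor` and `file_metadata` are excluded from serialization.
    payload = reader.json(exclude_unset=True)
    restored = S3Reader.parse_raw(payload)
    assert restored.bucket == reader.bucket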

Co-authored-by: Logan Markewich <logan.markewich@live.com>
---
 .../llama_index/core/readers/base.py          |  22 +-
 .../llama-index-readers-s3/CHANGELOG.md       |   5 +
 .../readers/llama-index-readers-s3/README.md  |   2 +-
 .../llama_index/readers/s3/base.py            | 210 ++++++++----------
 .../llama-index-readers-s3/pyproject.toml     |  11 +-
 .../llama-index-readers-s3/requirements.txt   |   2 +-
 .../llama-index-readers-s3/tests/BUILD        |  10 +-
 .../tests/test_readers_s3.py                  | 110 ++++++++-
 8 files changed, 246 insertions(+), 126 deletions(-)
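
For reference, a minimal usage sketch of the new s3fs-backed code path
(bucket, prefix, and credential values below are placeholders):

    from llama_index.readers.s3 import S3Reader

    reader = S3Reader(
        bucket="my-bucket",
        prefix="reports",         # only read keys under this prefix
        recursive=True,           # descend into nested "directories"
        aws_access_id="...",      # omit to fall back to ~/.aws/credentials
        aws_access_secret="...",
    )
    documents = reader.load_data()  # streamed via s3fs; no local temp dir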

diff --git a/llama-index-core/llama_index/core/readers/base.py b/llama-index-core/llama_index/core/readers/base.py
index 75e333166c..d62d9af0ab 100644
--- a/llama-index-core/llama_index/core/readers/base.py
+++ b/llama-index-core/llama_index/core/readers/base.py
@@ -1,7 +1,14 @@
 """Base reader class."""
 
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+)
 
 if TYPE_CHECKING:
     from llama_index.core.bridge.langchain import Document as LCDocument
@@ -27,6 +34,19 @@ class BaseReader(ABC):
         docs = self.load_data(**load_kwargs)
         return [d.to_langchain_format() for d in docs]
 
+    @classmethod
+    def __modify_schema__(cls, field_schema: Dict[str, Any], field: Optional[Any]):
+        # Needed for pydantic v1 to work
+        field_schema.update({"title": cls.__name__})
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, core_schema, handler
+    ):  # Needed for pydantic v2 to work
+        json_schema = handler(core_schema)
+        json_schema = handler.resolve_ref_schema(json_schema)
+        json_schema["title"] = cls.__name__
+        return json_schema
+
 
 class BasePydanticReader(BaseReader, BaseComponent):
     """Serialiable Data Loader with Pydantic."""
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md b/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md
index 36bff877ab..e1a472b4bb 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md
+++ b/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md
@@ -1,5 +1,10 @@
 # CHANGELOG
 
+## [0.1.4] - 2024-03-18
+
+- Refactor: read files directly from S3 with `s3fs`, using `SimpleDirectoryReader`'s new `fs` argument, instead of downloading them to local disk first.
+- Make the reader serializable by inheriting from `BasePydanticReader` instead of `BaseReader`.
+
 ## [0.1.2] - 2024-02-13
 
 - Add maintainers and keywords from library.json (llamahub)
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/README.md b/llama-index-integrations/readers/llama-index-readers-s3/README.md
index 96c3ed4d95..8b6a3cbaf9 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/README.md
+++ b/llama-index-integrations/readers/llama-index-readers-s3/README.md
@@ -2,7 +2,7 @@
 
 This loader parses any file stored on S3, or the entire Bucket (with an optional prefix filter) if no particular file is specified. When initializing `S3Reader`, you may pass in your [AWS Access Key](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html). If none are found, the loader assumes they are stored in `~/.aws/credentials`.
 
-All files are temporarily downloaded locally and subsequently parsed with `SimpleDirectoryReader`. Hence, you may also specify a custom `file_extractor`, relying on any of the loaders in this library (or your own)!
+All files are read directly from S3 via `s3fs` and parsed with `SimpleDirectoryReader`. Hence, you may also specify a custom `file_extractor`, relying on any of the loaders in this library (or your own)!
 
 ## Usage
 
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py b/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
index 47a525ca37..6d6d5a59ef 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
+++ b/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
@@ -1,145 +1,121 @@
-"""S3 file and directory reader.
+"""
+S3 file and directory reader.
 
 A loader that fetches a file or iterates through a directory on AWS S3.
 
 """
 
-import os
-import shutil
-import tempfile
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+import warnings
+from typing import Callable, Dict, List, Optional, Union
 
 from llama_index.core.readers import SimpleDirectoryReader
-from llama_index.core.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader, BasePydanticReader
 from llama_index.core.schema import Document
+from llama_index.core.bridge.pydantic import Field
+
+
+class S3Reader(BasePydanticReader):
+    """
+    General reader for any S3 file or directory.
+
+    If key is not set, the entire bucket (filtered by prefix) is parsed.
+
+    Args:
+    bucket (str): the name of your S3 bucket
+    key (Optional[str]): the name of the specific file. If none is provided,
+        this loader will iterate through the entire bucket.
+    prefix (Optional[str]): the prefix to filter by in the case that the loader
+        iterates through the entire bucket. Defaults to empty string.
+    recursive (bool): Whether to recursively search in subdirectories.
+        True by default.
+    file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file
+        extension to a BaseReader class that specifies how to convert that file
+        to text. See `SimpleDirectoryReader` for more details.
+    required_exts (Optional[List[str]]): List of required extensions.
+        Default is None.
+    num_files_limit (Optional[int]): Maximum number of files to read.
+        Default is None.
+    file_metadata (Optional[Callable[[str], Dict]]): A function that takes
+        in a filename and returns a Dict of metadata for the Document.
+        Default is None.
+    aws_access_id (Optional[str]): provide AWS access key directly.
+    aws_access_secret (Optional[str]): provide AWS secret access key directly.
+    aws_session_token (Optional[str]): provide AWS session token directly.
+    s3_endpoint_url (Optional[str]): provide S3 endpoint URL directly.
+    """
+
+    is_remote: bool = True
+
+    bucket: str
+    key: Optional[str] = None
+    prefix: Optional[str] = ""
+    recursive: bool = True
+    file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = Field(
+        default=None, exclude=True
+    )
+    required_exts: Optional[List[str]] = None
+    filename_as_id: bool = True
+    num_files_limit: Optional[int] = None
+    file_metadata: Optional[Callable[[str], Dict]] = Field(default=None, exclude=True)
+    aws_access_id: Optional[str] = None
+    aws_access_secret: Optional[str] = None
+    aws_session_token: Optional[str] = None
+    s3_endpoint_url: Optional[str] = "https://s3.amazonaws.com"
+    custom_reader_path: Optional[str] = None
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "S3Reader"
+
+    def load_s3_files_as_docs(self, temp_dir: Optional[str] = None) -> List[Document]:
+        """Load file(s) from S3. `temp_dir` is unused and kept for backward compatibility."""
+        from s3fs import S3FileSystem
 
+        s3fs = S3FileSystem(
+            key=self.aws_access_id,
+            endpoint_url=self.s3_endpoint_url,
+            secret=self.aws_access_secret,
+            token=self.aws_session_token,
+        )
 
-class S3Reader(BaseReader):
-    """General reader for any S3 file or directory."""
-
-    def __init__(
-        self,
-        *args: Any,
-        bucket: str,
-        key: Optional[str] = None,
-        prefix: Optional[str] = "",
-        file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = None,
-        required_exts: Optional[List[str]] = None,
-        filename_as_id: bool = True,
-        num_files_limit: Optional[int] = None,
-        file_metadata: Optional[Callable[[str], Dict]] = None,
-        aws_access_id: Optional[str] = None,
-        aws_access_secret: Optional[str] = None,
-        aws_session_token: Optional[str] = None,
-        s3_endpoint_url: Optional[str] = "https://s3.amazonaws.com",
-        custom_reader_path: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize S3 bucket and key, along with credentials if needed.
-
-        If key is not set, the entire bucket (filtered by prefix) is parsed.
-
-        Args:
-        bucket (str): the name of your S3 bucket
-        key (Optional[str]): the name of the specific file. If none is provided,
-            this loader will iterate through the entire bucket.
-        prefix (Optional[str]): the prefix to filter by in the case that the loader
-            iterates through the entire bucket. Defaults to empty string.
-        file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file
-            extension to a BaseReader class that specifies how to convert that file
-            to text. See `SimpleDirectoryReader` for more details.
-        required_exts (Optional[List[str]]): List of required extensions.
-            Default is None.
-        num_files_limit (Optional[int]): Maximum number of files to read.
-            Default is None.
-        file_metadata (Optional[Callable[str, Dict]]): A function that takes
-            in a filename and returns a Dict of metadata for the Document.
-            Default is None.
-        aws_access_id (Optional[str]): provide AWS access key directly.
-        aws_access_secret (Optional[str]): provide AWS access key directly.
-        s3_endpoint_url (Optional[str]): provide S3 endpoint URL directly.
-        """
-        super().__init__(*args, **kwargs)
-
-        self.bucket = bucket
-        self.key = key
-        self.prefix = prefix
-
-        self.file_extractor = file_extractor
-        self.required_exts = required_exts
-        self.filename_as_id = filename_as_id
-        self.num_files_limit = num_files_limit
-        self.file_metadata = file_metadata
-        self.custom_reader_path = custom_reader_path
-
-        self.aws_access_id = aws_access_id
-        self.aws_access_secret = aws_access_secret
-        self.aws_session_token = aws_session_token
-        self.s3_endpoint_url = s3_endpoint_url
-
-    def load_s3_files_as_docs(self, temp_dir) -> List[Document]:
-        """Load file(s) from S3."""
-        import boto3
-
-        s3 = boto3.resource("s3")
-        s3_client = boto3.client("s3")
-        if self.aws_access_id:
-            session = boto3.Session(
-                aws_access_key_id=self.aws_access_id,
-                aws_secret_access_key=self.aws_access_secret,
-                aws_session_token=self.aws_session_token,
-            )
-            s3 = session.resource("s3", endpoint_url=self.s3_endpoint_url)
-            s3_client = session.client("s3", endpoint_url=self.s3_endpoint_url)
+        input_dir = self.bucket
+        input_files = None
 
         if self.key:
-            filename = Path(self.key).name
-            suffix = Path(self.key).suffix
-            filepath = f"{temp_dir}/{filename}"
-            s3_client.download_file(self.bucket, self.key, filepath)
-        else:
-            bucket = s3.Bucket(self.bucket)
-            for i, obj in enumerate(bucket.objects.filter(Prefix=self.prefix)):
-                if self.num_files_limit is not None and i > self.num_files_limit:
-                    break
-                filename = Path(obj.key).name
-                suffix = Path(obj.key).suffix
-
-                is_dir = obj.key.endswith("/")  # skip folders
-                is_bad_ext = (
-                    self.required_exts is not None
-                    and suffix not in self.required_exts  # skip other extensions
-                )
-
-                if is_dir or is_bad_ext:
-                    continue
-
-                filepath = f"{temp_dir}/{filename}"
-                s3_client.download_file(self.bucket, obj.key, filepath)
+            input_files = [f"{self.bucket}/{self.key}"]
+        elif self.prefix:
+            input_dir = f"{input_dir}/{self.prefix}"
 
         loader = SimpleDirectoryReader(
-            temp_dir,
+            input_dir=input_dir,
+            input_files=input_files,
             file_extractor=self.file_extractor,
             required_exts=self.required_exts,
             filename_as_id=self.filename_as_id,
             num_files_limit=self.num_files_limit,
             file_metadata=self.file_metadata,
+            recursive=self.recursive,
+            fs=s3fs,
         )
 
         return loader.load_data()
 
     def load_data(self, custom_temp_subdir: str = None) -> List[Document]:
-        """Decide which directory to load files in - randomly generated directories under /tmp or a custom subdirectory under /tmp."""
-        if custom_temp_subdir is None:
-            with tempfile.TemporaryDirectory() as temp_dir:
-                documents = self.load_s3_files_as_docs(temp_dir)
-        else:
-            temp_dir = os.path.join("/tmp", custom_temp_subdir)
-            os.makedirs(temp_dir, exist_ok=True)
-            documents = self.load_s3_files_as_docs(temp_dir)
-            shutil.rmtree(temp_dir)
+        """
+        Load the file(s) from S3.
+
+        Args:
+            custom_temp_subdir (str, optional): This parameter is deprecated and unused. Defaults to None.
+
+        Returns:
+            List[Document]: A list of documents loaded from S3.
+        """
+        if custom_temp_subdir is not None:
+            warnings.warn(
+                "The `custom_temp_subdir` parameter is deprecated and unused. Please remove it from your code.",
+                DeprecationWarning,
+            )
 
+        documents = self.load_s3_files_as_docs()
         for doc in documents:
             doc.id_ = self.s3_endpoint_url + "_" + doc.id_
 
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml
index d9b00d2c91..1594d6891f 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml
+++ b/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml
@@ -29,12 +29,13 @@ license = "MIT"
 maintainers = ["thejessezhang"]
 name = "llama-index-readers-s3"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-boto3 = "^1.34.29"
+llama-index-readers-file = "^0.1.11"
+s3fs = "^2024.3.0"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
@@ -44,7 +45,7 @@ pre-commit = "3.2.0"
 pylint = "2.15.10"
 pytest = "7.2.1"
 pytest-mock = "3.11.1"
-ruff = "0.0.292"
+ruff = "^0.2.1"
 tree-sitter-languages = "^1.8.0"
 types-Deprecated = ">=0.1.0"
 types-PyYAML = "^6.0.12.12"
@@ -61,5 +62,9 @@ version = "<=23.9.1,>=23.7.0"
 extras = ["toml"]
 version = ">=v2.2.6"
 
+[tool.poetry.group.dev.dependencies.moto]
+extras = ["server"]
+version = "^5.0.3"
+
 [[tool.poetry.packages]]
 include = "llama_index/"
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt b/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt
index 30ddf823b8..a51391a948 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt
+++ b/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt
@@ -1 +1 @@
-boto3
+s3fs
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD b/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD
index dabf212d7e..f8c6a7f270 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD
+++ b/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD
@@ -1 +1,9 @@
-python_tests()
+python_tests(
+    name="tests",
+    dependencies=["llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/__init__.py"]
+)
+
+python_requirement(
+    name="s3fs",
+    requirements=["s3fs>=2024.3.0"],
+)
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py b/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py
index f5fbe5599a..27f887eb0e 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py
+++ b/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py
@@ -1,7 +1,113 @@
-from llama_index.core.readers.base import BaseReader
+from llama_index.core.readers.base import BasePydanticReader
 from llama_index.readers.s3 import S3Reader
+from moto.server import ThreadedMotoServer
+import pytest
+import os
+import requests
+from s3fs import S3FileSystem
+
+test_bucket = "test"
+files = [
+    "test/test.txt",
+    "test/subdir/test2.txt",
+    "test/subdir2/test3.txt",
+]
+ip_address = "127.0.0.1"
+port = 5555
+endpoint_url = f"http://{ip_address}:{port}"
+
+
+@pytest.fixture(scope="module")
+def s3_base():
+    # We create this module-level fixture to ensure that the server is only started once
+    s3_server = ThreadedMotoServer(ip_address=ip_address, port=port)
+    s3_server.start()
+    if "AWS_ACCESS_KEY_ID" not in os.environ:
+        os.environ["AWS_ACCESS_KEY_ID"] = "test"
+    if "AWS_SECRET_ACCESS_KEY" not in os.environ:
+        os.environ["AWS_SECRET_ACCESS_KEY"] = "test"
+    yield
+    s3_server.stop()
+
+
+@pytest.fixture()
+def init_s3_files(s3_base):
+    requests.post(f"{endpoint_url}/moto-api/reset")
+    s3fs = S3FileSystem(
+        endpoint_url=endpoint_url,
+    )
+    s3fs.mkdir(test_bucket)
+    s3fs.mkdir(f"{test_bucket}/subdir")
+    s3fs.mkdir(f"{test_bucket}/subdir2")
+    for file in files:
+        s3fs.touch(file)
 
 
 def test_class():
     names_of_base_classes = [b.__name__ for b in S3Reader.__mro__]
-    assert BaseReader.__name__ in names_of_base_classes
+    assert BasePydanticReader.__name__ in names_of_base_classes
+
+
+def test_load_all_files(init_s3_files):
+    reader = S3Reader(
+        bucket=test_bucket,
+        s3_endpoint_url=endpoint_url,
+    )
+    documents = reader.load_data()
+    assert len(documents) == len(files)
+
+
+def test_load_single_file(init_s3_files):
+    reader = S3Reader(
+        bucket=test_bucket,
+        key="test.txt",
+        s3_endpoint_url=endpoint_url,
+    )
+    documents = reader.load_data()
+    assert len(documents) == 1
+    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt"
+
+
+def test_load_with_prefix(init_s3_files):
+    reader = S3Reader(
+        bucket=test_bucket,
+        prefix="subdir",
+        s3_endpoint_url=endpoint_url,
+    )
+    documents = reader.load_data()
+    assert len(documents) == 1
+    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir/test2.txt"
+
+    reader.prefix = "subdir2"
+    documents = reader.load_data()
+    assert len(documents) == 1
+    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir2/test3.txt"
+
+
+def test_load_not_recursive(init_s3_files):
+    reader = S3Reader(
+        bucket=test_bucket,
+        recursive=False,
+        s3_endpoint_url=endpoint_url,
+    )
+    documents = reader.load_data()
+    assert len(documents) == 1
+    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt"
+
+
+def test_serialize():
+    reader = S3Reader(
+        bucket=test_bucket,
+        s3_endpoint_url=endpoint_url,
+    )
+
+    schema = reader.schema()
+    assert schema is not None
+    assert len(schema) > 0
+    assert "bucket" in schema["properties"]
+
+    json = reader.json(exclude_unset=True)
+
+    new_reader = S3Reader.parse_raw(json)
+    assert new_reader.bucket == reader.bucket
+    assert new_reader.s3_endpoint_url == reader.s3_endpoint_url
-- 
GitLab