From 181b32291744542bc8c55acf2edea73b4814b69d Mon Sep 17 00:00:00 2001
From: Javier Torres <javierandrestorresreyes@gmail.com>
Date: Wed, 20 Mar 2024 16:47:34 -0500
Subject: [PATCH] Use S3FS in S3Reader (#12061)

* use s3fs

* when key=none

* use s3fs

* changelog + version bump

* tests + serialization

* changelog

* pants update

* module mappings

* add files reader to BUILD

* more BUILD

* fix schema

* give BaseReader a __modify_schema__

* dont use ModelField

* remove validators

* Support schema in BaseReader for PydanticV2

---------

Co-authored-by: Logan Markewich <logan.markewich@live.com>
---
 .../llama_index/core/readers/base.py          |  22 +-
 .../llama-index-readers-s3/CHANGELOG.md       |   5 +
 .../readers/llama-index-readers-s3/README.md  |   2 +-
 .../llama_index/readers/s3/base.py            | 210 ++++++++----------
 .../llama-index-readers-s3/pyproject.toml     |  11 +-
 .../llama-index-readers-s3/requirements.txt   |   2 +-
 .../llama-index-readers-s3/tests/BUILD        |  10 +-
 .../tests/test_readers_s3.py                  | 110 ++++++++-
 8 files changed, 246 insertions(+), 126 deletions(-)

diff --git a/llama-index-core/llama_index/core/readers/base.py b/llama-index-core/llama_index/core/readers/base.py
index 75e333166c..d62d9af0ab 100644
--- a/llama-index-core/llama_index/core/readers/base.py
+++ b/llama-index-core/llama_index/core/readers/base.py
@@ -1,7 +1,14 @@
 """Base reader class."""
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+)
 
 if TYPE_CHECKING:
     from llama_index.core.bridge.langchain import Document as LCDocument
@@ -27,6 +34,19 @@ class BaseReader(ABC):
         docs = self.load_data(**load_kwargs)
         return [d.to_langchain_format() for d in docs]
 
+    @classmethod
+    def __modify_schema__(cls, field_schema: Dict[str, Any], field: Optional[Any]):
+        field_schema.update({"title": cls.__name__})
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, core_schema, handler
+    ):  # Needed for pydantic v2 to work
+        json_schema = handler(core_schema)
+        json_schema = handler.resolve_ref_schema(json_schema)
+        json_schema["title"] = cls.__name__
+        return json_schema
+
 
 class BasePydanticReader(BaseReader, BaseComponent):
     """Serialiable Data Loader with Pydantic."""
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md b/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md
index 36bff877ab..e1a472b4bb 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md
+++ b/llama-index-integrations/readers/llama-index-readers-s3/CHANGELOG.md
@@ -1,5 +1,10 @@
 # CHANGELOG
 
+## [0.1.4] - 2024-03-18
+
+- Refactor: Take advantage of `SimpleDirectoryReader` now supporting `fs` by using `s3fs` instead of downloading files to local disk.
+- Make the reader serializable by inheriting from `BasePydanticReader` instead of `BaseReader`.
+
 ## [0.1.2] - 2024-02-13
 
 - Add maintainers and keywords from library.json (llamahub)
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/README.md b/llama-index-integrations/readers/llama-index-readers-s3/README.md
index 96c3ed4d95..8b6a3cbaf9 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/README.md
+++ b/llama-index-integrations/readers/llama-index-readers-s3/README.md
@@ -2,7 +2,7 @@
 
 This loader parses any file stored on S3, or the entire Bucket (with an optional prefix filter) if no particular file is specified.
 
 When initializing `S3Reader`, you may pass in your [AWS Access Key](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html). If none are found, the loader assumes they are stored in `~/.aws/credentials`.
 
-All files are temporarily downloaded locally and subsequently parsed with `SimpleDirectoryReader`. Hence, you may also specify a custom `file_extractor`, relying on any of the loaders in this library (or your own)!
+All files are parsed with `SimpleDirectoryReader`. Hence, you may also specify a custom `file_extractor`, relying on any of the loaders in this library (or your own)!
 
 ## Usage
diff --git a/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py b/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
index 47a525ca37..6d6d5a59ef 100644
--- a/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
+++ b/llama-index-integrations/readers/llama-index-readers-s3/llama_index/readers/s3/base.py
@@ -1,145 +1,121 @@
-"""S3 file and directory reader.
+"""
+S3 file and directory reader.
 
 A loader that fetches a file or iterates through a directory on AWS S3.
 
 """
-import os
-import shutil
-import tempfile
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+import warnings
+from typing import Callable, Dict, List, Optional, Union
 
 from llama_index.core.readers import SimpleDirectoryReader
-from llama_index.core.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader, BasePydanticReader
 from llama_index.core.schema import Document
+from llama_index.core.bridge.pydantic import Field
+
+
+class S3Reader(BasePydanticReader):
+    """
+    General reader for any S3 file or directory.
+
+    If key is not set, the entire bucket (filtered by prefix) is parsed.
+
+    Args:
+        bucket (str): the name of your S3 bucket
+        key (Optional[str]): the name of the specific file. If none is provided,
+            this loader will iterate through the entire bucket.
+        prefix (Optional[str]): the prefix to filter by in the case that the loader
+            iterates through the entire bucket. Defaults to empty string.
+        recursive (bool): Whether to recursively search in subdirectories.
+            True by default.
+        file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file
+            extension to a BaseReader class that specifies how to convert that file
+            to text. See `SimpleDirectoryReader` for more details.
+        required_exts (Optional[List[str]]): List of required extensions.
+            Default is None.
+        num_files_limit (Optional[int]): Maximum number of files to read.
+            Default is None.
+        file_metadata (Optional[Callable[[str], Dict]]): A function that takes
+            in a filename and returns a Dict of metadata for the Document.
+            Default is None.
+        aws_access_id (Optional[str]): provide AWS access key directly.
+        aws_access_secret (Optional[str]): provide AWS secret access key directly.
+        s3_endpoint_url (Optional[str]): provide S3 endpoint URL directly.
+ """ + + is_remote: bool = True + + bucket: str + key: Optional[str] = None + prefix: Optional[str] = "" + recursive: bool = True + file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = Field( + default=None, exclude=True + ) + required_exts: Optional[List[str]] = None + filename_as_id: bool = True + num_files_limit: Optional[int] = None + file_metadata: Optional[Callable[[str], Dict]] = Field(default=None, exclude=True) + aws_access_id: Optional[str] = None + aws_access_secret: Optional[str] = None + aws_session_token: Optional[str] = None + s3_endpoint_url: Optional[str] = "https://s3.amazonaws.com" + custom_reader_path: Optional[str] = None + + @classmethod + def class_name(cls) -> str: + return "S3Reader" + + def load_s3_files_as_docs(self, temp_dir=None) -> List[Document]: + """Load file(s) from S3.""" + from s3fs import S3FileSystem + s3fs = S3FileSystem( + key=self.aws_access_id, + endpoint_url=self.s3_endpoint_url, + secret=self.aws_access_secret, + token=self.aws_session_token, + ) -class S3Reader(BaseReader): - """General reader for any S3 file or directory.""" - - def __init__( - self, - *args: Any, - bucket: str, - key: Optional[str] = None, - prefix: Optional[str] = "", - file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = None, - required_exts: Optional[List[str]] = None, - filename_as_id: bool = True, - num_files_limit: Optional[int] = None, - file_metadata: Optional[Callable[[str], Dict]] = None, - aws_access_id: Optional[str] = None, - aws_access_secret: Optional[str] = None, - aws_session_token: Optional[str] = None, - s3_endpoint_url: Optional[str] = "https://s3.amazonaws.com", - custom_reader_path: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Initialize S3 bucket and key, along with credentials if needed. - - If key is not set, the entire bucket (filtered by prefix) is parsed. - - Args: - bucket (str): the name of your S3 bucket - key (Optional[str]): the name of the specific file. If none is provided, - this loader will iterate through the entire bucket. - prefix (Optional[str]): the prefix to filter by in the case that the loader - iterates through the entire bucket. Defaults to empty string. - file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file - extension to a BaseReader class that specifies how to convert that file - to text. See `SimpleDirectoryReader` for more details. - required_exts (Optional[List[str]]): List of required extensions. - Default is None. - num_files_limit (Optional[int]): Maximum number of files to read. - Default is None. - file_metadata (Optional[Callable[str, Dict]]): A function that takes - in a filename and returns a Dict of metadata for the Document. - Default is None. - aws_access_id (Optional[str]): provide AWS access key directly. - aws_access_secret (Optional[str]): provide AWS access key directly. - s3_endpoint_url (Optional[str]): provide S3 endpoint URL directly. 
- """ - super().__init__(*args, **kwargs) - - self.bucket = bucket - self.key = key - self.prefix = prefix - - self.file_extractor = file_extractor - self.required_exts = required_exts - self.filename_as_id = filename_as_id - self.num_files_limit = num_files_limit - self.file_metadata = file_metadata - self.custom_reader_path = custom_reader_path - - self.aws_access_id = aws_access_id - self.aws_access_secret = aws_access_secret - self.aws_session_token = aws_session_token - self.s3_endpoint_url = s3_endpoint_url - - def load_s3_files_as_docs(self, temp_dir) -> List[Document]: - """Load file(s) from S3.""" - import boto3 - - s3 = boto3.resource("s3") - s3_client = boto3.client("s3") - if self.aws_access_id: - session = boto3.Session( - aws_access_key_id=self.aws_access_id, - aws_secret_access_key=self.aws_access_secret, - aws_session_token=self.aws_session_token, - ) - s3 = session.resource("s3", endpoint_url=self.s3_endpoint_url) - s3_client = session.client("s3", endpoint_url=self.s3_endpoint_url) + input_dir = self.bucket + input_files = None if self.key: - filename = Path(self.key).name - suffix = Path(self.key).suffix - filepath = f"{temp_dir}/{filename}" - s3_client.download_file(self.bucket, self.key, filepath) - else: - bucket = s3.Bucket(self.bucket) - for i, obj in enumerate(bucket.objects.filter(Prefix=self.prefix)): - if self.num_files_limit is not None and i > self.num_files_limit: - break - filename = Path(obj.key).name - suffix = Path(obj.key).suffix - - is_dir = obj.key.endswith("/") # skip folders - is_bad_ext = ( - self.required_exts is not None - and suffix not in self.required_exts # skip other extensions - ) - - if is_dir or is_bad_ext: - continue - - filepath = f"{temp_dir}/{filename}" - s3_client.download_file(self.bucket, obj.key, filepath) + input_files = [f"{self.bucket}/{self.key}"] + elif self.prefix: + input_dir = f"{input_dir}/{self.prefix}" loader = SimpleDirectoryReader( - temp_dir, + input_dir=input_dir, + input_files=input_files, file_extractor=self.file_extractor, required_exts=self.required_exts, filename_as_id=self.filename_as_id, num_files_limit=self.num_files_limit, file_metadata=self.file_metadata, + recursive=self.recursive, + fs=s3fs, ) return loader.load_data() def load_data(self, custom_temp_subdir: str = None) -> List[Document]: - """Decide which directory to load files in - randomly generated directories under /tmp or a custom subdirectory under /tmp.""" - if custom_temp_subdir is None: - with tempfile.TemporaryDirectory() as temp_dir: - documents = self.load_s3_files_as_docs(temp_dir) - else: - temp_dir = os.path.join("/tmp", custom_temp_subdir) - os.makedirs(temp_dir, exist_ok=True) - documents = self.load_s3_files_as_docs(temp_dir) - shutil.rmtree(temp_dir) + """ + Load the file(s) from S3. + + Args: + custom_temp_subdir (str, optional): This parameter is deprecated and unused. Defaults to None. + + Returns: + List[Document]: A list of documents loaded from S3. + """ + if custom_temp_subdir is not None: + warnings.warn( + "The `custom_temp_subdir` parameter is deprecated and unused. 
Please remove it from your code.", + DeprecationWarning, + ) + documents = self.load_s3_files_as_docs() for doc in documents: doc.id_ = self.s3_endpoint_url + "_" + doc.id_ diff --git a/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml index d9b00d2c91..1594d6891f 100644 --- a/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml +++ b/llama-index-integrations/readers/llama-index-readers-s3/pyproject.toml @@ -29,12 +29,13 @@ license = "MIT" maintainers = ["thejessezhang"] name = "llama-index-readers-s3" readme = "README.md" -version = "0.1.3" +version = "0.1.4" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -boto3 = "^1.34.29" +llama-index-readers-file = "^0.1.11" +s3fs = "^2024.3.0" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" @@ -44,7 +45,7 @@ pre-commit = "3.2.0" pylint = "2.15.10" pytest = "7.2.1" pytest-mock = "3.11.1" -ruff = "0.0.292" +ruff = "^0.2.1" tree-sitter-languages = "^1.8.0" types-Deprecated = ">=0.1.0" types-PyYAML = "^6.0.12.12" @@ -61,5 +62,9 @@ version = "<=23.9.1,>=23.7.0" extras = ["toml"] version = ">=v2.2.6" +[tool.poetry.group.dev.dependencies.moto] +extras = ["server"] +version = "^5.0.3" + [[tool.poetry.packages]] include = "llama_index/" diff --git a/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt b/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt index 30ddf823b8..a51391a948 100644 --- a/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt +++ b/llama-index-integrations/readers/llama-index-readers-s3/requirements.txt @@ -1 +1 @@ -boto3 +s3fs diff --git a/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD b/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD index dabf212d7e..f8c6a7f270 100644 --- a/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD +++ b/llama-index-integrations/readers/llama-index-readers-s3/tests/BUILD @@ -1 +1,9 @@ -python_tests() +python_tests( + name="tests", + dependencies=["llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/__init__.py"] +) + +python_requirement( + name="s3fs", + requirements=["s3fs>=2024.3.0"], +) diff --git a/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py b/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py index f5fbe5599a..27f887eb0e 100644 --- a/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py +++ b/llama-index-integrations/readers/llama-index-readers-s3/tests/test_readers_s3.py @@ -1,7 +1,113 @@ -from llama_index.core.readers.base import BaseReader +from llama_index.core.readers.base import BasePydanticReader from llama_index.readers.s3 import S3Reader +from moto.server import ThreadedMotoServer +import pytest +import os +import requests +from s3fs import S3FileSystem + +test_bucket = "test" +files = [ + "test/test.txt", + "test/subdir/test2.txt", + "test/subdir2/test3.txt", +] +ip_address = "127.0.0.1" +port = 5555 +endpoint_url = f"http://{ip_address}:{port}" + + +@pytest.fixture(scope="module") +def s3_base(): + # We create this module-level fixture to ensure that the server is only started once + s3_server = ThreadedMotoServer(ip_address=ip_address, port=port) + s3_server.start() + if "AWS_ACCESS_KEY_ID" not in os.environ: + os.environ["AWS_ACCESS_KEY_ID"] = "test" + if "AWS_SECRET_ACCESS_KEY" not in os.environ: + 
os.environ["AWS_SECRET_ACCESS_KEY"] = "test" + yield + s3_server.stop() + + +@pytest.fixture() +def init_s3_files(s3_base): + requests.post(f"{endpoint_url}/moto-api/reset") + s3fs = S3FileSystem( + endpoint_url=endpoint_url, + ) + s3fs.mkdir(test_bucket) + s3fs.mkdir(f"{test_bucket}/subdir") + s3fs.mkdir(f"{test_bucket}/subdir2") + for file in files: + s3fs.touch(file) def test_class(): names_of_base_classes = [b.__name__ for b in S3Reader.__mro__] - assert BaseReader.__name__ in names_of_base_classes + assert BasePydanticReader.__name__ in names_of_base_classes + + +def test_load_all_files(init_s3_files): + reader = S3Reader( + bucket=test_bucket, + s3_endpoint_url=endpoint_url, + ) + documents = reader.load_data() + assert len(documents) == len(files) + + +def test_load_single_file(init_s3_files): + reader = S3Reader( + bucket=test_bucket, + key="test.txt", + s3_endpoint_url=endpoint_url, + ) + documents = reader.load_data() + assert len(documents) == 1 + assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt" + + +def test_load_with_prefix(init_s3_files): + reader = S3Reader( + bucket=test_bucket, + prefix="subdir", + s3_endpoint_url=endpoint_url, + ) + documents = reader.load_data() + assert len(documents) == 1 + assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir/test2.txt" + + reader.prefix = "subdir2" + documents = reader.load_data() + assert len(documents) == 1 + assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir2/test3.txt" + + +def test_load_not_recursive(init_s3_files): + reader = S3Reader( + bucket=test_bucket, + recursive=False, + s3_endpoint_url=endpoint_url, + ) + documents = reader.load_data() + assert len(documents) == 1 + assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt" + + +def test_serialize(): + reader = S3Reader( + bucket=test_bucket, + s3_endpoint_url=endpoint_url, + ) + + schema = reader.schema() + assert schema is not None + assert len(schema) > 0 + assert "bucket" in schema["properties"] + + json = reader.json(exclude_unset=True) + + new_reader = S3Reader.parse_raw(json) + assert new_reader.bucket == reader.bucket + assert new_reader.s3_endpoint_url == reader.s3_endpoint_url -- GitLab