Unverified commit 181b3229, authored by Javier Torres, committed by GitHub

Use S3FS in S3Reader (#12061)


* use s3fs

* when key=none

* use s3fs

* changelog + version bump

* tests + serialization

* changelog

* pants update

* module mappings

* add files reader to BUILD

* more BUILD

* fix schema

* give BaseReader a __modify_schema__

* dont use ModelField

* remove validators

* Support schema in BaseReader for PydanticV2

---------

Co-authored-by: Logan Markewich <logan.markewich@live.com>
parent 0ade3860
"""Base reader class."""
from abc import ABC
from typing import TYPE_CHECKING, Any, Dict, Iterable, List
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
)
if TYPE_CHECKING:
from llama_index.core.bridge.langchain import Document as LCDocument
......@@ -27,6 +34,19 @@ class BaseReader(ABC):
docs = self.load_data(**load_kwargs)
return [d.to_langchain_format() for d in docs]
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any], field: Optional[Any]):
field_schema.update({"title": cls.__name__})
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema, handler
): # Needed for pydantic v2 to work
json_schema = handler(core_schema)
json_schema = handler.resolve_ref_schema(json_schema)
json_schema["title"] = cls.__name__
return json_schema
class BasePydanticReader(BaseReader, BaseComponent):
"""Serialiable Data Loader with Pydantic."""
......
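Together, the two hooks cover both pydantic major versions: v1 consults `__modify_schema__`, while v2 calls `__get_pydantic_json_schema__`. They exist so that models with `BaseReader`-typed fields (such as `file_extractor` on the S3 reader below) can still emit a JSON schema. A minimal sketch of the effect, assuming the pydantic v1-style `.schema()` API that the tests at the bottom of this commit exercise, with a hypothetical bucket name:

```python
# Sketch: schema generation succeeds even though BaseReader is an
# arbitrary (non-pydantic) class used inside a field annotation.
from llama_index.readers.s3 import S3Reader

reader = S3Reader(bucket="my-bucket")  # hypothetical bucket name
schema = reader.schema()
assert schema["title"] == "S3Reader"  # title set from the class name
assert "bucket" in schema["properties"]
```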
# CHANGELOG

## [0.1.4] - 2024-03-18

- Refactor: Take advantage of `SimpleDirectoryReader` now supporting `fs` by using `s3fs` instead of downloading files to local disk.
- Make the reader serializable by inheriting from `BasePydanticReader` instead of `BaseReader`.

## [0.1.2] - 2024-02-13

- Add maintainers and keywords from library.json (llamahub)
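The headline 0.1.4 change leans on the `fs` argument of `SimpleDirectoryReader`. A standalone sketch of that hook, using a hypothetical bucket and prefix, with credentials resolved by the standard AWS chain:

```python
# Sketch: SimpleDirectoryReader traversing S3 directly through an
# fsspec filesystem instead of a local download directory.
from llama_index.core.readers import SimpleDirectoryReader
from s3fs import S3FileSystem

fs = S3FileSystem()  # no explicit keys: env vars / ~/.aws/credentials / IAM role
docs = SimpleDirectoryReader(
    input_dir="my-bucket/some-prefix",  # hypothetical bucket/prefix
    recursive=True,
    fs=fs,
).load_data()
```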
@@ -2,7 +2,7 @@

This loader parses any file stored on S3, or the entire bucket (with an optional prefix filter) if no particular file is specified. When initializing `S3Reader`, you may pass in your [AWS Access Key](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html). If no credentials are found, the loader assumes they are stored in `~/.aws/credentials`.

All files are parsed with `SimpleDirectoryReader`. Hence, you may also specify a custom `file_extractor`, relying on any of the loaders in this library (or your own)!

## Usage
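A minimal usage sketch (bucket, key, and endpoint values are hypothetical; the parameters mirror the `S3Reader` fields defined below):

```python
from llama_index.readers.s3 import S3Reader

# Read one object; omit `key` (optionally setting `prefix`) to load a whole bucket.
reader = S3Reader(
    bucket="my-bucket",
    key="docs/report.txt",
    s3_endpoint_url="https://s3.amazonaws.com",
)
documents = reader.load_data()
```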
"""S3 file and directory reader.
"""
S3 file and directory reader.
A loader that fetches a file or iterates through a directory on AWS S3.
"""
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import warnings
from typing import Callable, Dict, List, Optional, Union
from llama_index.core.readers import SimpleDirectoryReader
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.base import BaseReader, BasePydanticReader
from llama_index.core.schema import Document
from llama_index.core.bridge.pydantic import Field
class S3Reader(BasePydanticReader):
    """
    General reader for any S3 file or directory.

    If key is not set, the entire bucket (filtered by prefix) is parsed.

    Args:
        bucket (str): the name of your S3 bucket
        key (Optional[str]): the name of the specific file. If none is provided,
            this loader will iterate through the entire bucket.
        prefix (Optional[str]): the prefix to filter by in the case that the loader
            iterates through the entire bucket. Defaults to empty string.
        recursive (bool): Whether to recursively search in subdirectories.
            True by default.
        file_extractor (Optional[Dict[str, BaseReader]]): A mapping of file
            extension to a BaseReader class that specifies how to convert that file
            to text. See `SimpleDirectoryReader` for more details.
        required_exts (Optional[List[str]]): List of required extensions.
            Default is None.
        num_files_limit (Optional[int]): Maximum number of files to read.
            Default is None.
        file_metadata (Optional[Callable[[str], Dict]]): A function that takes
            in a filename and returns a Dict of metadata for the Document.
            Default is None.
        aws_access_id (Optional[str]): provide AWS access key directly.
        aws_access_secret (Optional[str]): provide AWS secret access key directly.
        aws_session_token (Optional[str]): provide AWS session token directly.
        s3_endpoint_url (Optional[str]): provide S3 endpoint URL directly.
    """

    is_remote: bool = True

    bucket: str
    key: Optional[str] = None
    prefix: Optional[str] = ""
    recursive: bool = True
    file_extractor: Optional[Dict[str, Union[str, BaseReader]]] = Field(
        default=None, exclude=True
    )
    required_exts: Optional[List[str]] = None
    filename_as_id: bool = True
    num_files_limit: Optional[int] = None
    file_metadata: Optional[Callable[[str], Dict]] = Field(default=None, exclude=True)
    aws_access_id: Optional[str] = None
    aws_access_secret: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_endpoint_url: Optional[str] = "https://s3.amazonaws.com"
    custom_reader_path: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        return "S3Reader"
    def load_s3_files_as_docs(self, temp_dir=None) -> List[Document]:
        """Load file(s) from S3."""
        from s3fs import S3FileSystem

        s3fs = S3FileSystem(
            key=self.aws_access_id,
            endpoint_url=self.s3_endpoint_url,
            secret=self.aws_access_secret,
            token=self.aws_session_token,
        )

        input_dir = self.bucket
        input_files = None

        if self.key:
            input_files = [f"{self.bucket}/{self.key}"]
        elif self.prefix:
            input_dir = f"{input_dir}/{self.prefix}"

        loader = SimpleDirectoryReader(
            input_dir=input_dir,
            input_files=input_files,
            file_extractor=self.file_extractor,
            required_exts=self.required_exts,
            filename_as_id=self.filename_as_id,
            num_files_limit=self.num_files_limit,
            file_metadata=self.file_metadata,
            recursive=self.recursive,
            fs=s3fs,
        )

        return loader.load_data()
    def load_data(self, custom_temp_subdir: str = None) -> List[Document]:
        """
        Load the file(s) from S3.

        Args:
            custom_temp_subdir (str, optional): This parameter is deprecated and unused. Defaults to None.

        Returns:
            List[Document]: A list of documents loaded from S3.
        """
        if custom_temp_subdir is not None:
            warnings.warn(
                "The `custom_temp_subdir` parameter is deprecated and unused. Please remove it from your code.",
                DeprecationWarning,
            )

        documents = self.load_s3_files_as_docs()
        for doc in documents:
            # Prefix ids with the endpoint URL so ids stay unique across
            # S3-compatible endpoints.
            doc.id_ = self.s3_endpoint_url + "_" + doc.id_

        return documents
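Because `S3Reader` is now a `BasePydanticReader`, instances round-trip through JSON, which `test_serialize` below also checks. A minimal sketch with hypothetical values:

```python
# Sketch: serialize a reader and rebuild it (pydantic v1-style API).
from llama_index.readers.s3 import S3Reader

reader = S3Reader(bucket="my-bucket", key="docs/report.txt")  # hypothetical names
payload = reader.json(exclude_unset=True)  # excluded fields (e.g. callables) are dropped
restored = S3Reader.parse_raw(payload)
assert restored.bucket == reader.bucket
```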
@@ -29,12 +29,13 @@ license = "MIT"
maintainers = ["thejessezhang"]
name = "llama-index-readers-s3"
readme = "README.md"
version = "0.1.4"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
boto3 = "^1.34.29"
llama-index-readers-file = "^0.1.11"
s3fs = "^2024.3.0"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"

@@ -44,7 +45,7 @@ pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "^0.2.1"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"

@@ -61,5 +62,9 @@ version = "<=23.9.1,>=23.7.0"
extras = ["toml"]
version = ">=v2.2.6"

[tool.poetry.group.dev.dependencies.moto]
extras = ["server"]
version = "^5.0.3"

[[tool.poetry.packages]]
include = "llama_index/"
python_tests(
    name="tests",
    dependencies=["llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/__init__.py"],
)

python_requirement(
    name="s3fs",
    requirements=["s3fs>=2024.3.0"],
)
import os

import pytest
import requests
from llama_index.core.readers.base import BaseReader, BasePydanticReader
from llama_index.readers.s3 import S3Reader
from moto.server import ThreadedMotoServer
from s3fs import S3FileSystem

test_bucket = "test"
files = [
    "test/test.txt",
    "test/subdir/test2.txt",
    "test/subdir2/test3.txt",
]

ip_address = "127.0.0.1"
port = 5555
endpoint_url = f"http://{ip_address}:{port}"


@pytest.fixture(scope="module")
def s3_base():
    # Module-level fixture to ensure that the server is only started once.
    s3_server = ThreadedMotoServer(ip_address=ip_address, port=port)
    s3_server.start()
    if "AWS_ACCESS_KEY_ID" not in os.environ:
        os.environ["AWS_ACCESS_KEY_ID"] = "test"
    if "AWS_SECRET_ACCESS_KEY" not in os.environ:
        os.environ["AWS_SECRET_ACCESS_KEY"] = "test"
    yield
    s3_server.stop()


@pytest.fixture()
def init_s3_files(s3_base):
    requests.post(f"{endpoint_url}/moto-api/reset")
    s3fs = S3FileSystem(
        endpoint_url=endpoint_url,
    )
    s3fs.mkdir(test_bucket)
    s3fs.mkdir(f"{test_bucket}/subdir")
    s3fs.mkdir(f"{test_bucket}/subdir2")
    for file in files:
        s3fs.touch(file)


def test_class():
    names_of_base_classes = [b.__name__ for b in S3Reader.__mro__]
    assert BaseReader.__name__ in names_of_base_classes
    assert BasePydanticReader.__name__ in names_of_base_classes


def test_load_all_files(init_s3_files):
    reader = S3Reader(
        bucket=test_bucket,
        s3_endpoint_url=endpoint_url,
    )
    documents = reader.load_data()
    assert len(documents) == len(files)


def test_load_single_file(init_s3_files):
    reader = S3Reader(
        bucket=test_bucket,
        key="test.txt",
        s3_endpoint_url=endpoint_url,
    )
    documents = reader.load_data()
    assert len(documents) == 1
    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt"


def test_load_with_prefix(init_s3_files):
    reader = S3Reader(
        bucket=test_bucket,
        prefix="subdir",
        s3_endpoint_url=endpoint_url,
    )
    documents = reader.load_data()
    assert len(documents) == 1
    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir/test2.txt"

    reader.prefix = "subdir2"
    documents = reader.load_data()
    assert len(documents) == 1
    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/subdir2/test3.txt"


def test_load_not_recursive(init_s3_files):
    reader = S3Reader(
        bucket=test_bucket,
        recursive=False,
        s3_endpoint_url=endpoint_url,
    )
    documents = reader.load_data()
    assert len(documents) == 1
    assert documents[0].id_ == f"{endpoint_url}_{test_bucket}/test.txt"


def test_serialize():
    reader = S3Reader(
        bucket=test_bucket,
        s3_endpoint_url=endpoint_url,
    )

    schema = reader.schema()
    assert schema is not None
    assert len(schema) > 0
    assert "bucket" in schema["properties"]

    json = reader.json(exclude_unset=True)

    new_reader = S3Reader.parse_raw(json)
    assert new_reader.bucket == reader.bucket
    assert new_reader.s3_endpoint_url == reader.s3_endpoint_url