Unverified Commit 10d7b12c authored by ahmetkca, committed by GitHub

Refactor download_loader to change how loaders are downloaded. (#439)

parent a8015893
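For context, a minimal usage sketch (not part of this commit) of how the refactored `download_loader` is typically invoked. The loader name `SimpleWebPageReader`, its `load_data(urls=...)` signature, and the top-level import path are assumptions; the loader id `web/simple_web` matches the example referenced in the diff below.

```python
# Sketch only: assumes download_loader is re-exported from the gpt_index package.
from gpt_index import download_loader

# Resolves "SimpleWebPageReader" via library.json (loader id e.g. `web/simple_web`),
# then caches the loader as a package directory (base.py plus any extra files)
# under a llamahub_modules/ directory next to the download module,
# instead of a single flat .py file as before.
SimpleWebPageReader = download_loader("SimpleWebPageReader")

reader = SimpleWebPageReader()
documents = reader.load_data(urls=["https://example.com"])
```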
@@ -137,3 +137,4 @@ dmypy.json
# Jetbrains
.idea
modules/
\ No newline at end of file
@@ -6,7 +6,7 @@ import subprocess
import sys
from importlib import util
from pathlib import Path
from typing import Optional
from typing import List, Optional, Tuple
import pkg_resources
import requests
@@ -14,9 +14,66 @@ from pkg_resources import DistributionNotFound
from gpt_index.readers.base import BaseReader
LOADER_HUB_URL = (
"https://raw.githubusercontent.com/emptycrown/loader-hub/main/loader_hub"
)
LLAMA_HUB_CONTENTS_URL = "https://raw.githubusercontent.com/emptycrown/loader-hub/main"
LOADER_HUB_PATH = "/loader_hub"
LOADER_HUB_URL = LLAMA_HUB_CONTENTS_URL + LOADER_HUB_PATH
def _get_file_content(loader_hub_url: str, path: str) -> Tuple[str, int]:
"""Get the content of a file from the GitHub REST API."""
resp = requests.get(loader_hub_url + path)
return resp.text, resp.status_code
def get_exports(raw_content: str) -> List:
"""Read content of a Python file and returns a list of exported class names.
For example:
```python
from .a import A
from .b import B
__all__ = ["A", "B"]
```
will return `["A", "B"]`.
Args:
- raw_content: The content of a Python file as a string.
Returns:
A list of exported class names.
"""
exports = []
for line in raw_content.splitlines():
line = line.strip()
if line.startswith("__all__"):
exports = line.split("=")[1].strip().strip("[").strip("]").split(",")
exports = [export.strip().strip("'").strip('"') for export in exports]
return exports
def rewrite_exports(exports: List[str]) -> None:
"""Write the `__all__` variable to the `__init__.py` file in the modules dir.
Removes the line that contains `__all__` and appends a new line with the updated
`__all__` variable.
Args:
- exports: A list of exported class names.
"""
dirpath = Path(__file__).parent / "llamahub_modules"
init_path = f"{dirpath}/__init__.py"
with open(init_path, "r") as f:
lines = f.readlines()
with open(init_path, "w") as f:
for line in lines:
line = line.strip()
if line.startswith("__all__"):
continue
f.write(line + os.linesep)
f.write(f"__all__ = {list(set(exports))}" + os.linesep)
def download_loader(
@@ -41,13 +98,14 @@ def download_loader(
Returns:
A Loader.
"""
dirpath = ".modules"
dirpath = Path(__file__).parent / "llamahub_modules"
if not os.path.exists(dirpath):
# Create a new directory because it does not exist
os.makedirs(dirpath)
library_path = f"{dirpath}/library.json"
loader_id = None # e.g. `web/simple_web`
extra_files = [] # e.g. `web/simple_web/utils.py`
# Check cache first
if not refresh_cache and os.path.exists(library_path):
@@ -55,41 +113,70 @@
library = json.load(f)
if loader_class in library:
loader_id = library[loader_class]["id"]
extra_files = library[loader_class].get("extra_files", [])
# Fetch up-to-date library from remote repo if loader_id not found
if loader_id is None:
response = requests.get(f"{loader_hub_url}/library.json")
library = json.loads(response.text)
library_raw_content, _ = _get_file_content(loader_hub_url, "/library.json")
library = json.loads(library_raw_content)
if loader_class not in library:
raise ValueError("Loader class name not found in library")
loader_id = library[loader_class]["id"]
extra_files = library[loader_class].get("extra_files", [])
# Update cache
with open(library_path, "w") as f:
f.write(response.text)
f.write(library_raw_content)
assert loader_id is not None
if loader_id is None:
raise ValueError("Loader class name not found in library")
# Load the module
loader_filename = loader_id.replace("/", "-")
loader_path = f"{dirpath}/{loader_filename}.py"
requirements_path = f"{dirpath}/{loader_filename}_requirements.txt"
loader_path = f"{dirpath}/{loader_id}"
requirements_path = f"{loader_path}/requirements.txt"
if refresh_cache or not os.path.exists(loader_path):
response = requests.get(f"{loader_hub_url}/{loader_id}/base.py")
response_text = response.text
os.makedirs(loader_path, exist_ok=True)
basepy_raw_content, _ = _get_file_content(
loader_hub_url, f"/{loader_id}/base.py"
)
if use_gpt_index_import:
response_text = response_text.replace(
basepy_raw_content = basepy_raw_content.replace(
"import llama_index", "import gpt_index"
)
response_text = response_text.replace("from llama_index", "from gpt_index")
with open(loader_path, "w") as f:
f.write(response_text)
basepy_raw_content = basepy_raw_content.replace(
"from llama_index", "from gpt_index"
)
with open(f"{loader_path}/base.py", "w") as f:
f.write(basepy_raw_content)
# Get content of extra files if there are any
# and write them under the loader directory
for extra_file in extra_files:
extra_file_raw_content, _ = _get_file_content(
loader_hub_url, f"/{loader_id}/{extra_file}"
)
# If the extra file is an __init__.py file, we need to
# add the exports to the __init__.py file in the modules directory
if extra_file == "__init__.py":
loader_exports = get_exports(extra_file_raw_content)
existing_exports = []
if os.path.exists(dirpath / "__init__.py"):
with open(dirpath / "__init__.py", "r+") as f:
f.write(f"from .{loader_id} import {', '.join(loader_exports)}")
existing_exports = get_exports(f.read())
rewrite_exports(existing_exports + loader_exports)
with open(f"{loader_path}/{extra_file}", "w") as f:
f.write(extra_file_raw_content)
if not os.path.exists(requirements_path):
response = requests.get(f"{loader_hub_url}/{loader_id}/requirements.txt")
if response.status_code == 200:
# NOTE: need to check the status code
response_txt, status_code = _get_file_content(
loader_hub_url, f"/{loader_id}/requirements.txt"
)
if status_code == 200:
with open(requirements_path, "w") as f:
f.write(response.text)
f.write(response_txt)
# Install dependencies if there are any and not already installed
if os.path.exists(requirements_path):
@@ -102,10 +189,11 @@ def download_loader(
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", requirements_path]
)
spec = util.spec_from_file_location("custom_loader", location=loader_path)
spec = util.spec_from_file_location(
"custom_loader", location=f"{loader_path}/base.py"
)
if spec is None:
raise ValueError(f"Could not find file: {loader_path}.")
raise ValueError(f"Could not find file: {loader_path}/base.py.")
module = util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
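
To make the new `__init__.py` bookkeeping concrete, here is a rough sketch (not part of the commit; the import path is an assumption) of how the helpers added in this diff cooperate: `get_exports` pulls the names out of a module's `__all__` line, and `rewrite_exports` merges them into the `__all__` of `llamahub_modules/__init__.py`, deduplicating so repeated downloads accumulate exports rather than overwrite them.

```python
# Sketch only: assumes these helpers live in gpt_index.readers.download.
from gpt_index.readers.download import get_exports, rewrite_exports

raw = '''
from .base import SimpleWebPageReader

__all__ = ["SimpleWebPageReader"]
'''

names = get_exports(raw)
print(names)  # -> ['SimpleWebPageReader']

# Merges the names into __all__ of llamahub_modules/__init__.py, dropping duplicates.
# (That __init__.py must already exist; rewrite_exports opens it for reading first.)
rewrite_exports(names + ["BeautifulSoupWebReader"])
```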