Skip to content
Snippets Groups Projects
Unverified Commit 934d0486 authored by Andrei Fajardo's avatar Andrei Fajardo Committed by GitHub
Browse files

[FIX] download_llama_pack for python packages containing multiple packs (#11272)

* use recursive download to get all files

* fix constants

* remove branch and repo constants

* remove f string
parent 7b937adc
No related branches found
No related tags found
No related merge requests found
...@@ -13,17 +13,15 @@ import requests ...@@ -13,17 +13,15 @@ import requests
from llama_index.core.download.utils import ( from llama_index.core.download.utils import (
ChangeDirectory, ChangeDirectory,
get_file_content, get_file_content,
get_source_files_list,
initialize_directory, initialize_directory,
get_source_files_recursive,
) )
BRANCH = "nerdai/migration-v0_10_0"
REPO = "nerdai"
LLAMA_PACKS_CONTENTS_URL = ( LLAMA_PACKS_CONTENTS_URL = (
f"https://raw.githubusercontent.com/{REPO}/llama_index/{BRANCH}/llama-index-packs" "https://raw.githubusercontent.com/run-llama/llama_index/main/llama-index-packs"
) )
LLAMA_PACKS_SOURCE_FILES_GITHUB_TREE_URL = ( LLAMA_PACKS_SOURCE_FILES_GITHUB_TREE_URL = (
f"https://github.com/{REPO}/llama_index/tree/{BRANCH}/llama-index-packs" "https://github.com/run-llama/llama_index/tree/main"
) )
PY_NAMESPACE = "llama_index/packs" PY_NAMESPACE = "llama_index/packs"
...@@ -51,17 +49,24 @@ def download_module_and_reqs( ...@@ -51,17 +49,24 @@ def download_module_and_reqs(
os.makedirs(module_path, exist_ok=True) os.makedirs(module_path, exist_ok=True)
# download all source files # download all source files
source_files = get_source_files_list( source_files = get_source_files_recursive(
str(remote_source_dir_path), str(remote_source_dir_path),
f"/{package}/{PY_NAMESPACE}/{sub_module}", f"/llama-index-packs/{package}/{PY_NAMESPACE}/{sub_module}",
) )
for source_file in source_files: for source_file in source_files:
source_file_raw_content, _ = get_file_content( source_file_raw_content, _ = get_file_content(
str(remote_dir_path), str(remote_dir_path),
f"/{package}/{PY_NAMESPACE}/{sub_module}/{source_file}", f"{source_file}",
) )
with open(f"{module_path}/{source_file}", "w") as f: local_source_file_path = (
f"{local_dir_path}/{'/'.join(source_file.split('/')[2:])}"
)
# ensure parent dir of file exists
Path(local_source_file_path).parent.absolute().mkdir(
parents=True, exist_ok=True
)
with open(local_source_file_path, "w") as f:
f.write(source_file_raw_content) f.write(source_file_raw_content)
# pyproject.toml and README # pyproject.toml and README
...@@ -99,7 +104,7 @@ def download_llama_pack_template( ...@@ -99,7 +104,7 @@ def download_llama_pack_template(
refresh_cache: bool = False, refresh_cache: bool = False,
custom_dir: Optional[str] = None, custom_dir: Optional[str] = None,
custom_path: Optional[str] = None, custom_path: Optional[str] = None,
base_file_name: str = "base.py", base_file_name: str = "__init__.py",
) -> Any: ) -> Any:
# create directory / get path # create directory / get path
dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir) dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)
......
...@@ -95,6 +95,47 @@ def get_source_files_list(source_tree_url: str, path: str) -> List[str]: ...@@ -95,6 +95,47 @@ def get_source_files_list(source_tree_url: str, path: str) -> List[str]:
return [item["name"] for item in payload["tree"]["items"]] return [item["name"] for item in payload["tree"]["items"]]
def recursive_tree_traverse(
    tree_urls: List[str], acc: List[str], source_tree_url: str
) -> List[str]:
    """Traverse GitHub tree pages to collect all file paths in a folder.

    Args:
        tree_urls: URLs of the GitHub tree pages still to be visited.
        acc: Accumulator of file paths found so far. It is extended in
            place and is also the object returned.
        source_tree_url: Base tree URL; the ``path`` of every directory
            entry is appended to it (after a ``/``) to form the next URL
            to visit.

    Returns:
        ``acc``, extended with the ``path`` of every file entry found,
        with each occurrence of ``"llama-index-packs/"`` replaced by ``"/"``.

    Raises:
        ValueError: If a tree page cannot be fetched, or its JSON body
            does not contain ``payload -> tree -> items``.
    """
    from collections import deque

    # An explicit FIFO queue visits pages in the same order as a
    # tail-recursive "rest + newly found" traversal, but without adding a
    # Python stack frame per page (which could hit the recursion limit).
    pending = deque(tree_urls)
    while pending:
        url = pending.popleft()
        try:
            # Bounded wait so an unresponsive server cannot hang the caller.
            res = requests.get(url, timeout=30)
            tree_elements = res.json()["payload"]["tree"]["items"]
        except Exception as err:
            # Keep the original failure attached as __cause__ for debugging.
            raise ValueError("Failed to traverse github tree source.") from err

        # Sub-directories are queued behind the URLs that are already waiting.
        pending.extend(
            [
                source_tree_url + "/" + el["path"]
                for el in tree_elements
                if el["contentType"] == "directory"
            ]
        )
        acc.extend(
            [
                el["path"].replace("llama-index-packs/", "/")
                for el in tree_elements
                if el["contentType"] == "file"
            ]
        )
    return acc
def get_source_files_recursive(source_tree_url: str, path: str) -> List[str]:
    """Return every file path found under a GitHub folder, walking sub-folders.

    The starting page is ``source_tree_url`` followed by ``path`` and the
    ``?recursive=1`` query string; the walk itself is delegated to
    ``recursive_tree_traverse`` with an empty accumulator.
    """
    start_page = "".join((source_tree_url, path, "?recursive=1"))
    return recursive_tree_traverse(
        tree_urls=[start_page],
        acc=[],
        source_tree_url=source_tree_url,
    )
class ChangeDirectory: class ChangeDirectory:
"""Context manager for changing the current working directory.""" """Context manager for changing the current working directory."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment