From efe4e372c779ce9236231d42fc800761ddb71bf8 Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Sun, 18 Feb 2024 14:18:35 -0600 Subject: [PATCH] add back GitHub repo file filter (#10949) --- .../readers/github/repository/base.py | 346 +++++++++++++----- .../llama-index-readers-github/pyproject.toml | 2 +- 2 files changed, 251 insertions(+), 97 deletions(-) diff --git a/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py b/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py index 15f35755c..4c2360667 100644 --- a/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py +++ b/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py @@ -9,6 +9,7 @@ the text extracted from the files using the parser. import asyncio import base64 import binascii +import enum import logging import os import pathlib @@ -18,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from llama_index.core.readers.base import BaseReader from llama_index.core.readers.file.base import _try_loading_included_file_formats from llama_index.core.schema import Document + from llama_index.readers.github.repository.github_client import ( GitBranchResponseModel, GitCommitResponseModel, @@ -32,7 +34,6 @@ from llama_index.readers.github.repository.utils import ( logger = logging.getLogger(__name__) - DEFAULT_FILE_READER_CLS = _try_loading_included_file_formats() @@ -51,56 +52,70 @@ class GithubRepositoryReader(BaseReader): """ + class FilterType(enum.Enum): + """ + Filter type. + + Used to determine whether the filter is inclusive or exclusive. + + Attributes: + - EXCLUDE: Exclude the files in the directories or with the extensions. + - INCLUDE: Include only the files in the directories or with the extensions. + """ + + EXCLUDE = enum.auto() + INCLUDE = enum.auto() + def __init__( self, + github_client: GithubClient, owner: str, repo: str, - use_parser: bool = True, + use_parser: bool = False, verbose: bool = False, - github_token: Optional[str] = None, concurrent_requests: int = 5, - ignore_file_extensions: Optional[List[str]] = None, - ignore_directories: Optional[List[str]] = None, + timeout: Optional[int] = 5, + filter_directories: Optional[Tuple[List[str], FilterType]] = None, + filter_file_extensions: Optional[Tuple[List[str], FilterType]] = None, ): """ Initialize params. Args: + - github_client (BaseGithubClient): Github client. - owner (str): Owner of the repository. - repo (str): Name of the repository. - use_parser (bool): Whether to use the parser to extract the text from the files. - verbose (bool): Whether to print verbose messages. - - github_token (str): Github token. If not provided, - it will be read from the GITHUB_TOKEN environment variable. - concurrent_requests (int): Number of concurrent requests to make to the Github API. - - ignore_file_extensions (List[str]): List of file extensions to ignore. - i.e. ['.png', '.jpg'] - - ignore_directories (List[str]): List of directories to ignore. - i.e. ['node_modules', 'dist'] + - timeout (int or None): Timeout for the requests to the Github API. Default is 5. + - filter_directories (Optional[Tuple[List[str], FilterType]]): Tuple + containing a list of directories and a FilterType. If the FilterType + is INCLUDE, only the files in the directories in the list will be + included. If the FilterType is EXCLUDE, the files in the directories + in the list will be excluded. + - filter_file_extensions (Optional[Tuple[List[str], FilterType]]): Tuple + containing a list of file extensions and a FilterType. If the + FilterType is INCLUDE, only the files with the extensions in the list + will be included. If the FilterType is EXCLUDE, the files with the + extensions in the list will be excluded. Raises: - `ValueError`: If the github_token is not provided and the GITHUB_TOKEN environment variable is not set. """ super().__init__() - if github_token is None: - github_token = os.getenv("GITHUB_TOKEN") - if github_token is None: - raise ValueError( - "Please provide a Github token. " - "You can do so by passing it as an argument or " - + "by setting the GITHUB_TOKEN environment variable." - ) self._owner = owner self._repo = repo self._use_parser = use_parser self._verbose = verbose self._concurrent_requests = concurrent_requests - self._ignore_file_extensions = ignore_file_extensions - self._ignore_directories = ignore_directories + self._timeout = timeout + self._filter_directories = filter_directories + self._filter_file_extensions = filter_file_extensions # Set up the event loop try: @@ -110,11 +125,98 @@ class GithubRepositoryReader(BaseReader): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self._client = GithubClient(github_token) + self._github_client = github_client self._file_readers: Dict[str, BaseReader] = {} self._supported_suffix = list(DEFAULT_FILE_READER_CLS.keys()) + def _check_filter_directories(self, tree_obj_path: str) -> bool: + """ + Check if a tree object should be allowed based on the directories. + + :param `tree_obj_path`: path of the tree object i.e. 'llama_index/readers' + + :return: True if the tree object should be allowed, False otherwise + """ + if self._filter_directories is None: + return True + filter_directories, filter_type = self._filter_directories + print_if_verbose( + self._verbose, + f"Checking {tree_obj_path} whether to {filter_type} it" + + f" based on the filter directories: {filter_directories}", + ) + + if filter_type == self.FilterType.EXCLUDE: + print_if_verbose( + self._verbose, + f"Checking if {tree_obj_path} is not a subdirectory of any of the" + " filter directories", + ) + return not any( + tree_obj_path.startswith(directory) for directory in filter_directories + ) + if filter_type == self.FilterType.INCLUDE: + print_if_verbose( + self._verbose, + f"Checking if {tree_obj_path} is a subdirectory of any of the filter" + " directories", + ) + return any( + tree_obj_path.startswith(directory) + or directory.startswith(tree_obj_path) + for directory in filter_directories + ) + raise ValueError( + f"Unknown filter type: {filter_type}. " + "Please use either 'INCLUDE' or 'EXCLUDE'." + ) + + def _check_filter_file_extensions(self, tree_obj_path: str) -> bool: + """ + Check if a tree object should be allowed based on the file extensions. + + :param `tree_obj_path`: path of the tree object i.e. 'llama_index/indices' + + :return: True if the tree object should be allowed, False otherwise + """ + if self._filter_file_extensions is None: + return True + filter_file_extensions, filter_type = self._filter_file_extensions + print_if_verbose( + self._verbose, + f"Checking {tree_obj_path} whether to {filter_type} it" + + f" based on the filter file extensions: {filter_file_extensions}", + ) + + if filter_type == self.FilterType.EXCLUDE: + return get_file_extension(tree_obj_path) not in filter_file_extensions + if filter_type == self.FilterType.INCLUDE: + return get_file_extension(tree_obj_path) in filter_file_extensions + raise ValueError( + f"Unknown filter type: {filter_type}. " + "Please use either 'INCLUDE' or 'EXCLUDE'." + ) + + def _allow_tree_obj(self, tree_obj_path: str, tree_obj_type: str) -> bool: + """ + Check if a tree object should be allowed. + + :param `tree_obj_path`: path of the tree object + + :return: True if the tree object should be allowed, False otherwise + + """ + if self._filter_directories is not None and tree_obj_type == "tree": + return self._check_filter_directories(tree_obj_path) + + if self._filter_file_extensions is not None and tree_obj_type == "blob": + return self._check_filter_directories( + tree_obj_path + ) and self._check_filter_file_extensions(tree_obj_path) + + return True + def _load_data_from_commit(self, commit_sha: str) -> List[Document]: """ Load data from a commit. @@ -126,7 +228,9 @@ class GithubRepositoryReader(BaseReader): :return: list of documents """ commit_response: GitCommitResponseModel = self._loop.run_until_complete( - self._client.get_commit(self._owner, self._repo, commit_sha) + self._github_client.get_commit( + self._owner, self._repo, commit_sha, timeout=self._timeout + ) ) tree_sha = commit_response.commit.tree.sha @@ -135,7 +239,7 @@ class GithubRepositoryReader(BaseReader): print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs") return self._loop.run_until_complete( - self._generate_documents(blobs_and_paths=blobs_and_paths) + self._generate_documents(blobs_and_paths=blobs_and_paths, id=commit_sha) ) def _load_data_from_branch(self, branch: str) -> List[Document]: @@ -149,7 +253,9 @@ class GithubRepositoryReader(BaseReader): :return: list of documents """ branch_data: GitBranchResponseModel = self._loop.run_until_complete( - self._client.get_branch(self._owner, self._repo, branch) + self._github_client.get_branch( + self._owner, self._repo, branch, timeout=self._timeout + ) ) tree_sha = branch_data.commit.commit.tree.sha @@ -158,7 +264,7 @@ class GithubRepositoryReader(BaseReader): print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs") return self._loop.run_until_complete( - self._generate_documents(blobs_and_paths=blobs_and_paths) + self._generate_documents(blobs_and_paths=blobs_and_paths, id=branch) ) def load_data( @@ -191,7 +297,11 @@ class GithubRepositoryReader(BaseReader): raise ValueError("You must specify one of commit or branch.") async def _recurse_tree( - self, tree_sha: str, current_path: str = "", current_depth: int = 0 + self, + tree_sha: str, + current_path: str = "", + current_depth: int = 0, + max_depth: int = -1, ) -> Any: """ Recursively get all blob tree objects in a tree. @@ -206,64 +316,79 @@ class GithubRepositoryReader(BaseReader): :return: list of tuples of (tree object, file's full path relative to the root of the repo) """ + if max_depth != -1 and current_depth > max_depth: + return [] + blobs_and_full_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = [] print_if_verbose( - self._verbose, "\t" * current_depth + f"current path: {current_path}" + self._verbose, + "\t" * current_depth + f"current path: {current_path}", ) - tree_data: GitTreeResponseModel = await self._client.get_tree( - self._owner, self._repo, tree_sha + tree_data: GitTreeResponseModel = await self._github_client.get_tree( + self._owner, self._repo, tree_sha, timeout=self._timeout + ) + print_if_verbose( + self._verbose, "\t" * current_depth + f"tree data: {tree_data}" ) print_if_verbose( self._verbose, "\t" * current_depth + f"processing tree {tree_sha}" ) for tree_obj in tree_data.tree: file_path = os.path.join(current_path, tree_obj.path) + if not self._allow_tree_obj(file_path, tree_obj.type): + print_if_verbose( + self._verbose, + "\t" * current_depth + f"ignoring {tree_obj.path} due to filter", + ) + continue + + print_if_verbose( + self._verbose, + "\t" * current_depth + f"tree object: {tree_obj}", + ) + if tree_obj.type == "tree": print_if_verbose( self._verbose, "\t" * current_depth + f"recursing into {tree_obj.path}", ) - if self._ignore_directories is not None: - if tree_obj.path in self._ignore_directories: - print_if_verbose( - self._verbose, - "\t" * current_depth - + f"ignoring tree {tree_obj.path} due to directory", - ) - continue blobs_and_full_paths.extend( - await self._recurse_tree(tree_obj.sha, file_path, current_depth + 1) + await self._recurse_tree( + tree_obj.sha, file_path, current_depth + 1, max_depth + ) ) elif tree_obj.type == "blob": print_if_verbose( - self._verbose, "\t" * current_depth + f"found blob {tree_obj.path}" + self._verbose, + "\t" * current_depth + f"found blob {tree_obj.path}", ) - if self._ignore_file_extensions is not None: - if get_file_extension(file_path) in self._ignore_file_extensions: - print_if_verbose( - self._verbose, - "\t" * current_depth - + f"ignoring blob {tree_obj.path} due to file extension", - ) - continue + blobs_and_full_paths.append((tree_obj, file_path)) + + print_if_verbose( + self._verbose, + "\t" * current_depth + f"blob and full paths: {blobs_and_full_paths}", + ) return blobs_and_full_paths async def _generate_documents( - self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] + self, + blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]], + id: str = "", ) -> List[Document]: """ Generate documents from a list of blobs and their full paths. :param `blobs_and_paths`: list of tuples of (tree object, file's full path in the repo relative to the root of the repo) + :param `id`: the branch name or commit sha used when loading the repo :return: list of documents """ buffered_iterator = BufferedGitBlobDataIterator( blobs_and_paths=blobs_and_paths, - github_client=self._client, + github_client=self._github_client, owner=self._owner, repo=self._repo, loop=self._loop, @@ -297,6 +422,11 @@ class GithubRepositoryReader(BaseReader): if document is not None: documents.append(document) continue + print_if_verbose( + self._verbose, + f"could not parse {full_path} as a supported file type" + + " - falling back to decoding as utf-8 raw text", + ) try: if decoded_bytes is None: @@ -312,19 +442,27 @@ class GithubRepositoryReader(BaseReader): f"got {len(decoded_text)} characters" + f"- adding to documents - {full_path}", ) + url = os.path.join( + "https://github.com/", self._owner, self._repo, "blob/", id, full_path + ) document = Document( text=decoded_text, - id_=blob_data.sha, - metadata={ + doc_id=blob_data.sha, + extra_info={ "file_path": full_path, "file_name": full_path.split("/")[-1], + "url": url, }, ) documents.append(document) return documents def _parse_supported_file( - self, file_path: str, file_content: bytes, tree_sha: str, tree_path: str + self, + file_path: str, + file_content: bytes, + tree_sha: str, + tree_path: str, ) -> Optional[Document]: """ Parse a file if it is supported by a parser. @@ -351,42 +489,44 @@ class GithubRepositoryReader(BaseReader): + f"as {file_extension} with " + f"{reader.__class__.__name__}", ) - with tempfile.TemporaryDirectory() as tmpdirname, tempfile.NamedTemporaryFile( - dir=tmpdirname, - suffix=f".{file_extension}", - mode="w+b", - delete=False, - ) as tmpfile: - print_if_verbose( - self._verbose, - "created a temporary file" + f"{tmpfile.name} for parsing {file_path}", - ) - tmpfile.write(file_content) - tmpfile.flush() - tmpfile.close() - try: - docs = reader.load_data(pathlib.Path(tmpfile.name)) - parsed_file = "\n\n".join([doc.get_content() for doc in docs]) - except Exception as e: - print_if_verbose(self._verbose, f"error while parsing {file_path}") - logger.error( - "Error while parsing " - + f"{file_path} with " - + f"{reader.__class__.__name__}:\n{e}" + with tempfile.TemporaryDirectory() as tmpdirname: + with tempfile.NamedTemporaryFile( + dir=tmpdirname, + suffix=f".{file_extension}", + mode="w+b", + delete=False, + ) as tmpfile: + print_if_verbose( + self._verbose, + "created a temporary file" + + f"{tmpfile.name} for parsing {file_path}", + ) + tmpfile.write(file_content) + tmpfile.flush() + tmpfile.close() + try: + docs = reader.load_data(pathlib.Path(tmpfile.name)) + parsed_file = "\n\n".join([doc.get_text() for doc in docs]) + except Exception as e: + print_if_verbose(self._verbose, f"error while parsing {file_path}") + logger.error( + "Error while parsing " + + f"{file_path} with " + + f"{reader.__class__.__name__}:\n{e}" + ) + parsed_file = None + finally: + os.remove(tmpfile.name) + if parsed_file is None: + return None + return Document( + text=parsed_file, + doc_id=tree_sha, + extra_info={ + "file_path": file_path, + "file_name": tree_path, + }, ) - parsed_file = None - finally: - os.remove(tmpfile.name) - if parsed_file is None: - return None - return Document( - text=parsed_file, - id_=tree_sha, - metadata={ - "file_path": file_path, - "file_name": tree_path, - }, - ) if __name__ == "__main__": @@ -404,13 +544,31 @@ if __name__ == "__main__": return wrapper + github_client = GithubClient(github_token=os.environ["GITHUB_TOKEN"], verbose=True) + reader1 = GithubRepositoryReader( - github_token=os.environ["GITHUB_TOKEN"], + github_client=github_client, owner="jerryjliu", repo="llama_index", use_parser=False, verbose=True, - ignore_directories=["examples"], + filter_directories=( + ["docs"], + GithubRepositoryReader.FilterType.INCLUDE, + ), + filter_file_extensions=( + [ + ".png", + ".jpg", + ".jpeg", + ".gif", + ".svg", + ".ico", + "json", + ".ipynb", + ], + GithubRepositoryReader.FilterType.EXCLUDE, + ), ) @timeit @@ -420,19 +578,15 @@ if __name__ == "__main__": commit_sha="22e198b3b166b5facd2843d6a62ac0db07894a13" ) for document in documents: - print(document.metadata) + print(document.extra_info) @timeit def load_data_from_branch() -> None: """Load data from a branch.""" documents = reader1.load_data(branch="main") for document in documents: - print(document.metadata) + print(document.extra_info) input("Press enter to load github repository from branch name...") load_data_from_branch() - - input("Press enter to load github repository from commit sha...") - - load_data_from_commit() diff --git a/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml index e60346bb3..adf90af5d 100644 --- a/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml +++ b/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml @@ -26,7 +26,7 @@ license = "MIT" maintainers = ["ahmetkca", "moncho", "rwood-97"] name = "llama-index-readers-github" readme = "README.md" -version = "0.1.2" +version = "0.1.3" [tool.poetry.dependencies] python = ">=3.8.1,<3.12" -- GitLab