From efe4e372c779ce9236231d42fc800761ddb71bf8 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Sun, 18 Feb 2024 14:18:35 -0600
Subject: [PATCH] add back GitHub repo file filter (#10949)

---
 .../readers/github/repository/base.py         | 346 +++++++++++++-----
 .../llama-index-readers-github/pyproject.toml |   2 +-
 2 files changed, 251 insertions(+), 97 deletions(-)

diff --git a/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py b/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py
index 15f35755c..4c2360667 100644
--- a/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py
+++ b/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py
@@ -9,6 +9,7 @@ the text extracted from the files using the parser.
 import asyncio
 import base64
 import binascii
+import enum
 import logging
 import os
 import pathlib
@@ -18,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers.file.base import _try_loading_included_file_formats
 from llama_index.core.schema import Document
+
 from llama_index.readers.github.repository.github_client import (
     GitBranchResponseModel,
     GitCommitResponseModel,
@@ -32,7 +34,6 @@ from llama_index.readers.github.repository.utils import (
 
 logger = logging.getLogger(__name__)
 
-
 DEFAULT_FILE_READER_CLS = _try_loading_included_file_formats()
 
 
@@ -51,56 +52,70 @@ class GithubRepositoryReader(BaseReader):
 
     """
 
+    class FilterType(enum.Enum):
+        """
+        Filter type.
+
+        Used to determine whether the filter is inclusive or exclusive.
+
+        Attributes:
+            - EXCLUDE: Exclude the files in the directories or with the extensions.
+            - INCLUDE: Include only the files in the directories or with the extensions.
+        """
+
+        EXCLUDE = enum.auto()
+        INCLUDE = enum.auto()
+
     def __init__(
         self,
+        github_client: GithubClient,
         owner: str,
         repo: str,
-        use_parser: bool = True,
+        use_parser: bool = False,
         verbose: bool = False,
-        github_token: Optional[str] = None,
         concurrent_requests: int = 5,
-        ignore_file_extensions: Optional[List[str]] = None,
-        ignore_directories: Optional[List[str]] = None,
+        timeout: Optional[int] = 5,
+        filter_directories: Optional[Tuple[List[str], FilterType]] = None,
+        filter_file_extensions: Optional[Tuple[List[str], FilterType]] = None,
     ):
         """
         Initialize params.
 
         Args:
+            - github_client (GithubClient): Github client.
             - owner (str): Owner of the repository.
             - repo (str): Name of the repository.
             - use_parser (bool): Whether to use the parser to extract
                 the text from the files.
             - verbose (bool): Whether to print verbose messages.
-            - github_token (str): Github token. If not provided,
-                it will be read from the GITHUB_TOKEN environment variable.
             - concurrent_requests (int): Number of concurrent requests to
                 make to the Github API.
-            - ignore_file_extensions (List[str]): List of file extensions to ignore.
-                i.e. ['.png', '.jpg']
-            - ignore_directories (List[str]): List of directories to ignore.
-                i.e. ['node_modules', 'dist']
+            - timeout (int or None): Timeout for the requests to the Github API. Default is 5.
+            - filter_directories (Optional[Tuple[List[str], FilterType]]): Tuple
+                containing a list of directories and a FilterType. If the FilterType
+                is INCLUDE, only the files in the directories in the list will be
+                included. If the FilterType is EXCLUDE, the files in the directories
+                in the list will be excluded.
+            - filter_file_extensions (Optional[Tuple[List[str], FilterType]]): Tuple
+                containing a list of file extensions and a FilterType. If the
+                FilterType is INCLUDE, only the files with the extensions in the list
+                will be included. If the FilterType is EXCLUDE, the files with the
+                extensions in the list will be excluded.
 
         Raises:
             - `ValueError`: If the github_token is not provided and
                 the GITHUB_TOKEN environment variable is not set.
         """
         super().__init__()
-        if github_token is None:
-            github_token = os.getenv("GITHUB_TOKEN")
-            if github_token is None:
-                raise ValueError(
-                    "Please provide a Github token. "
-                    "You can do so by passing it as an argument or "
-                    + "by setting the GITHUB_TOKEN environment variable."
-                )
 
         self._owner = owner
         self._repo = repo
         self._use_parser = use_parser
         self._verbose = verbose
         self._concurrent_requests = concurrent_requests
-        self._ignore_file_extensions = ignore_file_extensions
-        self._ignore_directories = ignore_directories
+        self._timeout = timeout
+        self._filter_directories = filter_directories
+        self._filter_file_extensions = filter_file_extensions
 
         # Set up the event loop
         try:
@@ -110,11 +125,98 @@ class GithubRepositoryReader(BaseReader):
             self._loop = asyncio.new_event_loop()
             asyncio.set_event_loop(self._loop)
 
-        self._client = GithubClient(github_token)
+        self._github_client = github_client
 
         self._file_readers: Dict[str, BaseReader] = {}
         self._supported_suffix = list(DEFAULT_FILE_READER_CLS.keys())
 
+    def _check_filter_directories(self, tree_obj_path: str) -> bool:
+        """
+        Check if a tree object should be allowed based on the directories.
+
+        :param `tree_obj_path`: path of the tree object i.e. 'llama_index/readers'
+
+        :return: True if the tree object should be allowed, False otherwise
+        """
+        if self._filter_directories is None:
+            return True
+        filter_directories, filter_type = self._filter_directories
+        print_if_verbose(
+            self._verbose,
+            f"Checking {tree_obj_path} whether to {filter_type} it"
+            + f" based on the filter directories: {filter_directories}",
+        )
+
+        if filter_type == self.FilterType.EXCLUDE:
+            print_if_verbose(
+                self._verbose,
+                f"Checking if {tree_obj_path} is not a subdirectory of any of the"
+                " filter directories",
+            )
+            return not any(
+                tree_obj_path.startswith(directory) for directory in filter_directories
+            )
+        if filter_type == self.FilterType.INCLUDE:
+            print_if_verbose(
+                self._verbose,
+                f"Checking if {tree_obj_path} is a subdirectory of any of the filter"
+                " directories",
+            )
+            return any(
+                tree_obj_path.startswith(directory)
+                or directory.startswith(tree_obj_path)
+                for directory in filter_directories
+            )
+        raise ValueError(
+            f"Unknown filter type: {filter_type}. "
+            "Please use either 'INCLUDE' or 'EXCLUDE'."
+        )
+
+    def _check_filter_file_extensions(self, tree_obj_path: str) -> bool:
+        """
+        Check if a tree object should be allowed based on the file extensions.
+
+        :param `tree_obj_path`: path of the tree object i.e. 'llama_index/indices'
+
+        :return: True if the tree object should be allowed, False otherwise
+        """
+        if self._filter_file_extensions is None:
+            return True
+        filter_file_extensions, filter_type = self._filter_file_extensions
+        print_if_verbose(
+            self._verbose,
+            f"Checking {tree_obj_path} whether to {filter_type} it"
+            + f" based on the filter file extensions: {filter_file_extensions}",
+        )
+
+        if filter_type == self.FilterType.EXCLUDE:
+            return get_file_extension(tree_obj_path) not in filter_file_extensions
+        if filter_type == self.FilterType.INCLUDE:
+            return get_file_extension(tree_obj_path) in filter_file_extensions
+        raise ValueError(
+            f"Unknown filter type: {filter_type}. "
+            "Please use either 'INCLUDE' or 'EXCLUDE'."
+        )
+
+    def _allow_tree_obj(self, tree_obj_path: str, tree_obj_type: str) -> bool:
+        """
+        Check if a tree object should be allowed.
+
+        :param `tree_obj_path`: path of the tree object
+
+        :return: True if the tree object should be allowed, False otherwise
+
+        """
+        if self._filter_directories is not None and tree_obj_type == "tree":
+            return self._check_filter_directories(tree_obj_path)
+
+        if self._filter_file_extensions is not None and tree_obj_type == "blob":
+            return self._check_filter_directories(
+                tree_obj_path
+            ) and self._check_filter_file_extensions(tree_obj_path)
+
+        return True
+
     def _load_data_from_commit(self, commit_sha: str) -> List[Document]:
         """
         Load data from a commit.
@@ -126,7 +228,9 @@ class GithubRepositoryReader(BaseReader):
         :return: list of documents
         """
         commit_response: GitCommitResponseModel = self._loop.run_until_complete(
-            self._client.get_commit(self._owner, self._repo, commit_sha)
+            self._github_client.get_commit(
+                self._owner, self._repo, commit_sha, timeout=self._timeout
+            )
         )
 
         tree_sha = commit_response.commit.tree.sha
@@ -135,7 +239,7 @@ class GithubRepositoryReader(BaseReader):
         print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
 
         return self._loop.run_until_complete(
-            self._generate_documents(blobs_and_paths=blobs_and_paths)
+            self._generate_documents(blobs_and_paths=blobs_and_paths, id=commit_sha)
         )
 
     def _load_data_from_branch(self, branch: str) -> List[Document]:
@@ -149,7 +253,9 @@ class GithubRepositoryReader(BaseReader):
         :return: list of documents
         """
         branch_data: GitBranchResponseModel = self._loop.run_until_complete(
-            self._client.get_branch(self._owner, self._repo, branch)
+            self._github_client.get_branch(
+                self._owner, self._repo, branch, timeout=self._timeout
+            )
         )
 
         tree_sha = branch_data.commit.commit.tree.sha
@@ -158,7 +264,7 @@ class GithubRepositoryReader(BaseReader):
         print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
 
         return self._loop.run_until_complete(
-            self._generate_documents(blobs_and_paths=blobs_and_paths)
+            self._generate_documents(blobs_and_paths=blobs_and_paths, id=branch)
         )
 
     def load_data(
@@ -191,7 +297,11 @@ class GithubRepositoryReader(BaseReader):
         raise ValueError("You must specify one of commit or branch.")
 
     async def _recurse_tree(
-        self, tree_sha: str, current_path: str = "", current_depth: int = 0
+        self,
+        tree_sha: str,
+        current_path: str = "",
+        current_depth: int = 0,
+        max_depth: int = -1,
     ) -> Any:
         """
         Recursively get all blob tree objects in a tree.
@@ -206,64 +316,79 @@ class GithubRepositoryReader(BaseReader):
         :return: list of tuples of
             (tree object, file's full path relative to the root of the repo)
         """
+        if max_depth != -1 and current_depth > max_depth:
+            return []
+
         blobs_and_full_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = []
         print_if_verbose(
-            self._verbose, "\t" * current_depth + f"current path: {current_path}"
+            self._verbose,
+            "\t" * current_depth + f"current path: {current_path}",
         )
 
-        tree_data: GitTreeResponseModel = await self._client.get_tree(
-            self._owner, self._repo, tree_sha
+        tree_data: GitTreeResponseModel = await self._github_client.get_tree(
+            self._owner, self._repo, tree_sha, timeout=self._timeout
+        )
+        print_if_verbose(
+            self._verbose, "\t" * current_depth + f"tree data: {tree_data}"
         )
         print_if_verbose(
             self._verbose, "\t" * current_depth + f"processing tree {tree_sha}"
         )
         for tree_obj in tree_data.tree:
             file_path = os.path.join(current_path, tree_obj.path)
+            if not self._allow_tree_obj(file_path, tree_obj.type):
+                print_if_verbose(
+                    self._verbose,
+                    "\t" * current_depth + f"ignoring {tree_obj.path} due to filter",
+                )
+                continue
+
+            print_if_verbose(
+                self._verbose,
+                "\t" * current_depth + f"tree object: {tree_obj}",
+            )
+
             if tree_obj.type == "tree":
                 print_if_verbose(
                     self._verbose,
                     "\t" * current_depth + f"recursing into {tree_obj.path}",
                 )
-                if self._ignore_directories is not None:
-                    if tree_obj.path in self._ignore_directories:
-                        print_if_verbose(
-                            self._verbose,
-                            "\t" * current_depth
-                            + f"ignoring tree {tree_obj.path} due to directory",
-                        )
-                        continue
 
                 blobs_and_full_paths.extend(
-                    await self._recurse_tree(tree_obj.sha, file_path, current_depth + 1)
+                    await self._recurse_tree(
+                        tree_obj.sha, file_path, current_depth + 1, max_depth
+                    )
                 )
             elif tree_obj.type == "blob":
                 print_if_verbose(
-                    self._verbose, "\t" * current_depth + f"found blob {tree_obj.path}"
+                    self._verbose,
+                    "\t" * current_depth + f"found blob {tree_obj.path}",
                 )
-                if self._ignore_file_extensions is not None:
-                    if get_file_extension(file_path) in self._ignore_file_extensions:
-                        print_if_verbose(
-                            self._verbose,
-                            "\t" * current_depth
-                            + f"ignoring blob {tree_obj.path} due to file extension",
-                        )
-                        continue
+
                 blobs_and_full_paths.append((tree_obj, file_path))
+
+            print_if_verbose(
+                self._verbose,
+                "\t" * current_depth + f"blob and full paths: {blobs_and_full_paths}",
+            )
         return blobs_and_full_paths
 
     async def _generate_documents(
-        self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]]
+        self,
+        blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]],
+        id: str = "",
     ) -> List[Document]:
         """
         Generate documents from a list of blobs and their full paths.
 
         :param `blobs_and_paths`: list of tuples of
             (tree object, file's full path in the repo relative to the root of the repo)
+        :param `id`: the branch name or commit sha used when loading the repo
         :return: list of documents
         """
         buffered_iterator = BufferedGitBlobDataIterator(
             blobs_and_paths=blobs_and_paths,
-            github_client=self._client,
+            github_client=self._github_client,
             owner=self._owner,
             repo=self._repo,
             loop=self._loop,
@@ -297,6 +422,11 @@ class GithubRepositoryReader(BaseReader):
                 if document is not None:
                     documents.append(document)
                     continue
+                print_if_verbose(
+                    self._verbose,
+                    f"could not parse {full_path} as a supported file type"
+                    + " - falling back to decoding as utf-8 raw text",
+                )
 
             try:
                 if decoded_bytes is None:
@@ -312,19 +442,27 @@ class GithubRepositoryReader(BaseReader):
                 f"got {len(decoded_text)} characters"
                 + f"- adding to documents - {full_path}",
             )
+            url = os.path.join(
+                "https://github.com/", self._owner, self._repo, "blob/", id, full_path
+            )
             document = Document(
                 text=decoded_text,
-                id_=blob_data.sha,
-                metadata={
+                doc_id=blob_data.sha,
+                extra_info={
                     "file_path": full_path,
                     "file_name": full_path.split("/")[-1],
+                    "url": url,
                 },
             )
             documents.append(document)
         return documents
 
     def _parse_supported_file(
-        self, file_path: str, file_content: bytes, tree_sha: str, tree_path: str
+        self,
+        file_path: str,
+        file_content: bytes,
+        tree_sha: str,
+        tree_path: str,
     ) -> Optional[Document]:
         """
         Parse a file if it is supported by a parser.
@@ -351,42 +489,44 @@ class GithubRepositoryReader(BaseReader):
             + f"as {file_extension} with "
             + f"{reader.__class__.__name__}",
         )
-        with tempfile.TemporaryDirectory() as tmpdirname, tempfile.NamedTemporaryFile(
-            dir=tmpdirname,
-            suffix=f".{file_extension}",
-            mode="w+b",
-            delete=False,
-        ) as tmpfile:
-            print_if_verbose(
-                self._verbose,
-                "created a temporary file" + f"{tmpfile.name} for parsing {file_path}",
-            )
-            tmpfile.write(file_content)
-            tmpfile.flush()
-            tmpfile.close()
-            try:
-                docs = reader.load_data(pathlib.Path(tmpfile.name))
-                parsed_file = "\n\n".join([doc.get_content() for doc in docs])
-            except Exception as e:
-                print_if_verbose(self._verbose, f"error while parsing {file_path}")
-                logger.error(
-                    "Error while parsing "
-                    + f"{file_path} with "
-                    + f"{reader.__class__.__name__}:\n{e}"
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with tempfile.NamedTemporaryFile(
+                dir=tmpdirname,
+                suffix=f".{file_extension}",
+                mode="w+b",
+                delete=False,
+            ) as tmpfile:
+                print_if_verbose(
+                    self._verbose,
+                    "created a temporary file "
+                    + f"{tmpfile.name} for parsing {file_path}",
+                )
+                tmpfile.write(file_content)
+                tmpfile.flush()
+                tmpfile.close()
+                try:
+                    docs = reader.load_data(pathlib.Path(tmpfile.name))
+                    parsed_file = "\n\n".join([doc.get_text() for doc in docs])
+                except Exception as e:
+                    print_if_verbose(self._verbose, f"error while parsing {file_path}")
+                    logger.error(
+                        "Error while parsing "
+                        + f"{file_path} with "
+                        + f"{reader.__class__.__name__}:\n{e}"
+                    )
+                    parsed_file = None
+                finally:
+                    os.remove(tmpfile.name)
+                if parsed_file is None:
+                    return None
+                return Document(
+                    text=parsed_file,
+                    doc_id=tree_sha,
+                    extra_info={
+                        "file_path": file_path,
+                        "file_name": tree_path,
+                    },
                 )
-                parsed_file = None
-            finally:
-                os.remove(tmpfile.name)
-            if parsed_file is None:
-                return None
-            return Document(
-                text=parsed_file,
-                id_=tree_sha,
-                metadata={
-                    "file_path": file_path,
-                    "file_name": tree_path,
-                },
-            )
 
 
 if __name__ == "__main__":
@@ -404,13 +544,31 @@ if __name__ == "__main__":
 
         return wrapper
 
+    github_client = GithubClient(github_token=os.environ["GITHUB_TOKEN"], verbose=True)
+
     reader1 = GithubRepositoryReader(
-        github_token=os.environ["GITHUB_TOKEN"],
+        github_client=github_client,
         owner="jerryjliu",
         repo="llama_index",
         use_parser=False,
         verbose=True,
-        ignore_directories=["examples"],
+        filter_directories=(
+            ["docs"],
+            GithubRepositoryReader.FilterType.INCLUDE,
+        ),
+        filter_file_extensions=(
+            [
+                ".png",
+                ".jpg",
+                ".jpeg",
+                ".gif",
+                ".svg",
+                ".ico",
+                ".json",
+                ".ipynb",
+            ],
+            GithubRepositoryReader.FilterType.EXCLUDE,
+        ),
     )
 
     @timeit
@@ -420,19 +578,15 @@ if __name__ == "__main__":
             commit_sha="22e198b3b166b5facd2843d6a62ac0db07894a13"
         )
         for document in documents:
-            print(document.metadata)
+            print(document.extra_info)
 
     @timeit
     def load_data_from_branch() -> None:
         """Load data from a branch."""
         documents = reader1.load_data(branch="main")
         for document in documents:
-            print(document.metadata)
+            print(document.extra_info)
 
     input("Press enter to load github repository from branch name...")
 
     load_data_from_branch()
-
-    input("Press enter to load github repository from commit sha...")
-
-    load_data_from_commit()
diff --git a/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml
index e60346bb3..adf90af5d 100644
--- a/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml
+++ b/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml
@@ -26,7 +26,7 @@ license = "MIT"
 maintainers = ["ahmetkca", "moncho", "rwood-97"]
 name = "llama-index-readers-github"
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<3.12"
-- 
GitLab