diff --git a/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py b/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py index 1b1d34d01b8d31eec0e61f601c2a1ad34b8c6db1..7fc3a7120a0fa7175e8d865815087b4122f100ea 100644 --- a/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py +++ b/llama-index-integrations/readers/llama-index-readers-github/llama_index/readers/github/repository/base.py @@ -13,6 +13,7 @@ import enum import logging import os import pathlib +import re import tempfile from typing import Any, Callable, Dict, List, Optional, Tuple @@ -385,6 +386,13 @@ class GithubRepositoryReader(BaseReader): ) return blobs_and_full_paths + def _get_base_url(self, blob_url): + match = re.match(r"(https://[^/]+\.com/)", blob_url) + if match: + return match.group(1) + else: + return "https://github.com/" + async def _generate_documents( self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]], @@ -455,7 +463,12 @@ class GithubRepositoryReader(BaseReader): + f"- adding to documents - {full_path}", ) url = os.path.join( - "https://github.com/", self._owner, self._repo, "blob/", id, full_path + self._get_base_url(blob_data.url), + self._owner, + self._repo, + "blob/", + id, + full_path, ) document = Document( text=decoded_text, diff --git a/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml index 265b8c4333ee348b59da5ef786e0c8dc8b8082a2..a42630ff6549de0437841c5129f8f9f81cad1de6 100644 --- a/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml +++ b/llama-index-integrations/readers/llama-index-readers-github/pyproject.toml @@ -31,7 +31,7 @@ license = "MIT" maintainers = ["ahmetkca", "moncho", "rwood-97"] name = "llama-index-readers-github" readme = "README.md" -version = "0.3.0" 
"""Tests for ``GithubRepositoryReader._get_base_url``."""

import pytest
from llama_index.readers.github import GithubRepositoryReader

# Each entry pairs a blob URL with the base URL the reader should derive
# from it. The last two entries exercise the fallback to the public GitHub
# host (non-``.com`` domain and empty string).
BASE_URL_CASES = [
    ("https://github.com/owner/repo/blob/main/file.py", "https://github.com/"),
    (
        "https://github-enterprise.com/owner/repo/blob/main/file.py",
        "https://github-enterprise.com/",
    ),
    (
        "https://custom-domain.com/owner/repo/blob/main/file.py",
        "https://custom-domain.com/",
    ),
    (
        "https://subdomain.github.com/owner/repo/blob/main/file.py",
        "https://subdomain.github.com/",
    ),
    (
        "https://something.org/owner/repo/blob/main/file.py",
        "https://github.com/",
    ),
    ("", "https://github.com/"),
]


class MockGithubClient:
    """Empty stand-in passed to the reader as its ``github_client``."""


@pytest.fixture()
def github_reader():
    """Provide a reader for ``owner/repo`` wired to the mock client."""
    mock_client = MockGithubClient()
    return GithubRepositoryReader(
        github_client=mock_client, owner="owner", repo="repo"
    )


@pytest.mark.parametrize(("blob_url", "expected_base_url"), BASE_URL_CASES)
def test_get_base_url(github_reader, blob_url, expected_base_url):
    """Check the base URL derived from each blob URL in the table."""
    base_url = github_reader._get_base_url(blob_url)
    assert (
        base_url == expected_base_url
    ), f"Expected {expected_base_url}, but got {base_url}"