diff --git a/llama_index/extractors/metadata_extractors.py b/llama_index/extractors/metadata_extractors.py
index ccf7413c2b0222a605950082a731b63342e88e63..6dca49f0c932bf849447332a411aae67a4134517 100644
--- a/llama_index/extractors/metadata_extractors.py
+++ b/llama_index/extractors/metadata_extractors.py
@@ -99,21 +99,13 @@ class TitleExtractor(BaseExtractor):
         return "TitleExtractor"
 
     async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
-        nodes_to_extract_title = self.filter_nodes(nodes)
-
-        if not nodes_to_extract_title:
-            return []
-
-        nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes_to_extract_title)
+        nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes)
         titles_by_doc_id = await self.extract_titles(nodes_by_doc_id)
-
         return [{"document_title": titles_by_doc_id[node.ref_doc_id]} for node in nodes]
 
     def filter_nodes(self, nodes: Sequence[BaseNode]) -> List[BaseNode]:
         filtered_nodes: List[BaseNode] = []
         for node in nodes:
-            if len(filtered_nodes) >= self.nodes:
-                break
             if self.is_text_node_only and not isinstance(node, TextNode):
                 continue
             filtered_nodes.append(node)
@@ -127,7 +119,8 @@ class TitleExtractor(BaseExtractor):
             if key not in separated_items:
                 separated_items[key] = []
 
-            separated_items[key].append(node)
+            if len(separated_items[key]) < self.nodes:
+                separated_items[key].append(node)
 
         return separated_items
 
diff --git a/poetry.lock b/poetry.lock
index 12acf953252e73fbcf7d6d89943b497a819f5b8e..3d0e8e727c3ce3e252b20280a4e00e90e5db55cb 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4714,6 +4714,27 @@ files = [
 [package.extras]
 diagrams = ["jinja2", "railroad-diagrams"]
 
+[[package]]
+name = "pypdf"
+version = "4.0.0"
+description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "pypdf-4.0.0-py3-none-any.whl", hash = "sha256:071d837f9d109c260cc70b026d3ffad3db6f905131b51f39fe74f2fa435ea003"},
+    {file = "pypdf-4.0.0.tar.gz", hash = "sha256:637de66382238f3537fcf8f27fdb47e2c99dd5aa6de173a28a5b737f9a784973"},
+]
+
+[package.dependencies]
+typing_extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.10\""}
+
+[package.extras]
+crypto = ["PyCryptodome", "cryptography"]
+dev = ["black", "flit", "pip-tools", "pre-commit (<2.18.0)", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"]
+docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"]
+full = ["Pillow (>=8.0.0)", "PyCryptodome", "cryptography"]
+image = ["Pillow (>=8.0.0)"]
+
 [[package]]
 name = "pyreadline3"
 version = "3.4.1"
@@ -4911,7 +4932,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
"PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -7770,4 +7790,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "ae325bfb113de60bf4a4885c752164199ad099b4f1d21035bf00701e14842a8b" +content-hash = "6d11f5b95418266365cd690eaa14e59456e23c056c64157698f67fb333bba0b3" diff --git a/pyproject.toml b/pyproject.toml index 02739b535c4495f3b2f57819330ce5f3c52b33f5..576571900a17096e635aa0fcc29826ad305630e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" pymongo = "^4.5.0" # needed for tests +pypdf = "*" pytest = "7.2.1" pytest-asyncio = "0.21.0" pytest-dotenv = "0.5.2" diff --git a/tests/extractors/test_metadata_extractor.py b/tests/extractors/test_metadata_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2198ad856e94e6cd7c673c6ef8c05e4e08cf55 --- /dev/null +++ b/tests/extractors/test_metadata_extractor.py @@ -0,0 +1,68 @@ +"""Test dataset generation.""" + +import tempfile +import typing +import urllib.request + +from llama_index import SimpleDirectoryReader +from llama_index.extractors import ( + QuestionsAnsweredExtractor, + TitleExtractor, +) +from llama_index.ingestion import IngestionPipeline +from llama_index.llms import MockLLM +from llama_index.text_splitter import TokenTextSplitter + + +def two_random_integers(range_limit: int) -> typing.Tuple[int, int]: + import random + + index1 = random.randint(0, range_limit - 1) + index2 = index1 + while index2 == index1: + index2 = random.randint(0, range_limit - 1) + return index1, index2 + + +def test_metadata_extractor() -> None: + """Test metadata extraction.""" + llm = MockLLM() + + with tempfile.TemporaryDirectory() as tmpdirname: + urllib.request.urlretrieve( + "https://www.dropbox.com/scl/fi/6dlqdk6e2k1mjhi8dee5j/uber.pdf?rlkey=2jyoe49bg2vwdlz30l76czq6g&dl=1", + f"{tmpdirname}/10k-132.pdf", + ) + urllib.request.urlretrieve( + "https://www.dropbox.com/scl/fi/qn7g3vrk5mqb18ko4e5in/lyft.pdf?rlkey=j6jxtjwo8zbstdo4wz3ns8zoj&dl=1", + f"{tmpdirname}/10k-vFinal.pdf", + ) + + # Note the uninformative document file name, which may be a common scenario in a production setting + uber_docs = SimpleDirectoryReader( + input_files=[f"{tmpdirname}/10k-132.pdf"] + ).load_data() + uber_front_pages = uber_docs[0:3] + uber_content = uber_docs[63:69] + uber_docs = uber_front_pages + uber_content + + text_splitter = TokenTextSplitter( + separator=" ", chunk_size=512, chunk_overlap=128 + ) + + extractors = [ + TitleExtractor(nodes=5, llm=llm), + QuestionsAnsweredExtractor(questions=3, llm=llm), + ] + + transformations = [text_splitter, *extractors] + + pipeline = IngestionPipeline(transformations=transformations) + + uber_nodes = pipeline.run(documents=uber_docs) + + assert len(uber_nodes) == 20 + assert ( + uber_nodes[0].metadata["document_title"] + != uber_nodes[-1].metadata["document_title"] + )