Skip to content
Snippets Groups Projects
Unverified Commit b75d9d3d authored by Sasha Dog's avatar Sasha Dog Committed by GitHub
Browse files

fixed the bug introduced by pull request 10209 (#10226)

parent 5fc696a2
Branches
Tags
No related merge requests found
...@@ -99,21 +99,13 @@ class TitleExtractor(BaseExtractor): ...@@ -99,21 +99,13 @@ class TitleExtractor(BaseExtractor):
return "TitleExtractor" return "TitleExtractor"
async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]: async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
nodes_to_extract_title = self.filter_nodes(nodes) nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes)
if not nodes_to_extract_title:
return []
nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes_to_extract_title)
titles_by_doc_id = await self.extract_titles(nodes_by_doc_id) titles_by_doc_id = await self.extract_titles(nodes_by_doc_id)
return [{"document_title": titles_by_doc_id[node.ref_doc_id]} for node in nodes] return [{"document_title": titles_by_doc_id[node.ref_doc_id]} for node in nodes]
def filter_nodes(self, nodes: Sequence[BaseNode]) -> List[BaseNode]: def filter_nodes(self, nodes: Sequence[BaseNode]) -> List[BaseNode]:
filtered_nodes: List[BaseNode] = [] filtered_nodes: List[BaseNode] = []
for node in nodes: for node in nodes:
if len(filtered_nodes) >= self.nodes:
break
if self.is_text_node_only and not isinstance(node, TextNode): if self.is_text_node_only and not isinstance(node, TextNode):
continue continue
filtered_nodes.append(node) filtered_nodes.append(node)
...@@ -127,7 +119,8 @@ class TitleExtractor(BaseExtractor): ...@@ -127,7 +119,8 @@ class TitleExtractor(BaseExtractor):
if key not in separated_items: if key not in separated_items:
separated_items[key] = [] separated_items[key] = []
separated_items[key].append(node) if len(separated_items[key]) < self.nodes:
separated_items[key].append(node)
return separated_items return separated_items
......
...@@ -4714,6 +4714,27 @@ files = [ ...@@ -4714,6 +4714,27 @@ files = [
[package.extras] [package.extras]
diagrams = ["jinja2", "railroad-diagrams"] diagrams = ["jinja2", "railroad-diagrams"]
   
[[package]]
name = "pypdf"
version = "4.0.0"
description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files"
optional = false
python-versions = ">=3.6"
files = [
{file = "pypdf-4.0.0-py3-none-any.whl", hash = "sha256:071d837f9d109c260cc70b026d3ffad3db6f905131b51f39fe74f2fa435ea003"},
{file = "pypdf-4.0.0.tar.gz", hash = "sha256:637de66382238f3537fcf8f27fdb47e2c99dd5aa6de173a28a5b737f9a784973"},
]
[package.dependencies]
typing_extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.10\""}
[package.extras]
crypto = ["PyCryptodome", "cryptography"]
dev = ["black", "flit", "pip-tools", "pre-commit (<2.18.0)", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"]
docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"]
full = ["Pillow (>=8.0.0)", "PyCryptodome", "cryptography"]
image = ["Pillow (>=8.0.0)"]
[[package]] [[package]]
name = "pyreadline3" name = "pyreadline3"
version = "3.4.1" version = "3.4.1"
...@@ -4911,7 +4932,6 @@ files = [ ...@@ -4911,7 +4932,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
...@@ -7770,4 +7790,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc ...@@ -7770,4 +7790,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "ae325bfb113de60bf4a4885c752164199ad099b4f1d21035bf00701e14842a8b" content-hash = "6d11f5b95418266365cd690eaa14e59456e23c056c64157698f67fb333bba0b3"
...@@ -121,6 +121,7 @@ mypy = "0.991" ...@@ -121,6 +121,7 @@ mypy = "0.991"
pre-commit = "3.2.0" pre-commit = "3.2.0"
pylint = "2.15.10" pylint = "2.15.10"
pymongo = "^4.5.0" # needed for tests pymongo = "^4.5.0" # needed for tests
pypdf = "*"
pytest = "7.2.1" pytest = "7.2.1"
pytest-asyncio = "0.21.0" pytest-asyncio = "0.21.0"
pytest-dotenv = "0.5.2" pytest-dotenv = "0.5.2"
......
"""Test dataset generation."""
import tempfile
import typing
import urllib.request
from llama_index import SimpleDirectoryReader
from llama_index.extractors import (
QuestionsAnsweredExtractor,
TitleExtractor,
)
from llama_index.ingestion import IngestionPipeline
from llama_index.llms import MockLLM
from llama_index.text_splitter import TokenTextSplitter
def two_random_integers(range_limit: int) -> typing.Tuple[int, int]:
    """Return two distinct random integers drawn from ``[0, range_limit)``.

    Args:
        range_limit: Exclusive upper bound of the draw; must be >= 2 so
            that two distinct values exist.

    Returns:
        A tuple of two different integers, each in ``[0, range_limit)``.

    Raises:
        ValueError: If ``range_limit < 2``.  (The previous rejection-sampling
            loop would spin forever for ``range_limit == 1``; ``random.sample``
            fails fast instead.)
    """
    import random

    # random.sample draws without replacement, so the two values are
    # guaranteed distinct with no retry loop.
    index1, index2 = random.sample(range(range_limit), 2)
    return index1, index2
def test_metadata_extractor() -> None:
    """Test metadata extraction."""
    llm = MockLLM()

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Fetch the two 10-K filings into the temp dir.  Note the
        # uninformative document file names, which may be a common
        # scenario in a production setting.
        downloads = (
            (
                "https://www.dropbox.com/scl/fi/6dlqdk6e2k1mjhi8dee5j/uber.pdf?rlkey=2jyoe49bg2vwdlz30l76czq6g&dl=1",
                f"{tmpdirname}/10k-132.pdf",
            ),
            (
                "https://www.dropbox.com/scl/fi/qn7g3vrk5mqb18ko4e5in/lyft.pdf?rlkey=j6jxtjwo8zbstdo4wz3ns8zoj&dl=1",
                f"{tmpdirname}/10k-vFinal.pdf",
            ),
        )
        for url, destination in downloads:
            urllib.request.urlretrieve(url, destination)

        docs = SimpleDirectoryReader(
            input_files=[f"{tmpdirname}/10k-132.pdf"]
        ).load_data()
        # Keep the front matter plus a slice of body content only.
        uber_docs = docs[0:3] + docs[63:69]

        splitter = TokenTextSplitter(
            separator=" ", chunk_size=512, chunk_overlap=128
        )
        pipeline = IngestionPipeline(
            transformations=[
                splitter,
                TitleExtractor(nodes=5, llm=llm),
                QuestionsAnsweredExtractor(questions=3, llm=llm),
            ]
        )
        uber_nodes = pipeline.run(documents=uber_docs)

        assert len(uber_nodes) == 20
        # First and last nodes come from different input documents, so
        # their extracted titles are expected to differ.
        assert (
            uber_nodes[0].metadata["document_title"]
            != uber_nodes[-1].metadata["document_title"]
        )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment