Unverified commit cfb2d7a5, authored by Ravi Theja and committed by GitHub

Add logging to RAFT llamapack (#12275)

* Add logging

* Add logging

* resolve errors
parent 70c16530
@@ -4,8 +4,12 @@
 from typing import Any, List
 import random
+import logging

 from datasets import Dataset

+# Configure logging to output to the console, with messages of level INFO and above
+logging.basicConfig(level=logging.INFO)
+
 from llama_index.core.llama_pack.base import BaseLlamaPack
 from llama_index.core import SimpleDirectoryReader
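For context on the new logging setup: `logging.basicConfig` only takes effect if the root logger has no handlers yet, so applications that configure logging before importing the pack keep their own settings. A minimal sketch of how a consumer could adjust the verbosity (the format string is illustrative, not part of the pack):

import logging

# Configure logging before importing the pack; basicConfig is a no-op
# once the root logger already has handlers.
logging.basicConfig(
    level=logging.DEBUG,  # show DEBUG and above instead of the pack's INFO default
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# ...or quiet the pack's INFO messages afterwards:
logging.getLogger().setLevel(logging.WARNING)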
@@ -30,7 +34,6 @@ class RAFTDatasetPack(BaseLlamaPack):
         num_distract_docs: int = 3,
         chunk_size: int = DEFAULT_CHUNK_SIZE,
         default_breakpoint_percentile_threshold=DEFAULT_BREAKPOINT_PERCENTILE_THRESHOLD,
-        **kwargs: Any,
     ):
         self.file_path = file_path
         self.num_questions_per_chunk = num_questions_per_chunk
@@ -116,10 +119,7 @@ class RAFTDatasetPack(BaseLlamaPack):
        Takes in a `file_path`, retrieves the document, breaks it down into chunks of size
        `chunk_size`, and returns the chunks.
        """
-        chunks = []
         documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
-        # TODO: Should be changed to SemanticSplitterNodeParser
         splitter = SemanticSplitterNodeParser(
             buffer_size=1,
             breakpoint_percentile_threshold=self.default_breakpoint_percentile_threshold,
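For a sense of what `get_chunks` does with these pieces, here is a rough standalone sketch; the embedding model and input file are assumptions, since `SemanticSplitterNodeParser` needs an embedding model to locate semantic breakpoints:

from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding  # assumed embed model

documents = SimpleDirectoryReader(input_files=["paper.pdf"]).load_data()
splitter = SemanticSplitterNodeParser(
    buffer_size=1,
    breakpoint_percentile_threshold=95,  # split where adjacent-sentence similarity drops
    embed_model=OpenAIEmbedding(),
)
nodes = splitter.get_nodes_from_documents(documents)
chunks = [node.text for node in nodes]  # plain-text chunks, one per semantic segment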
@@ -156,7 +156,7 @@ class RAFTDatasetPack(BaseLlamaPack):
             datapt["type"] = "general"
             datapt["question"] = q

-            # add 4 distractor docs
+            # add distractor docs
             docs = [chunk]
             indices = list(range(len(chunks)))
             indices.remove(i)
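The lines around this change sample distractor documents from every chunk index except the oracle chunk's own; a standalone sketch of that idea, with hypothetical data:

import random

chunks = ["chunk-0", "chunk-1", "chunk-2", "chunk-3", "chunk-4"]
i = 2                                  # index of the oracle chunk
num_distract_docs = 3

docs = [chunks[i]]                     # the oracle chunk always comes first
indices = list(range(len(chunks)))
indices.remove(i)                      # a chunk never distracts from itself
for j in random.sample(indices, num_distract_docs):
    docs.append(chunks[j])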
@@ -199,15 +199,18 @@ class RAFTDatasetPack(BaseLlamaPack):
         else:
             self.ds = self.ds.add_item(datapt)

-    def run(self, *args: Any, **kwargs: Any) -> Any:
+    def run(self) -> Any:
         """Run the pipeline."""
         chunks = self.get_chunks(self.file_path, self.chunk_size)

+        logging.info(f"Number of chunks created: {len(chunks)}")
+
         self.num_distract_docs = (
             min(self.num_distract_docs, len(chunks)) - 1
         )  # should be less than number of chunks/ nodes created

-        for chunk in chunks:
+        for index, chunk in enumerate(chunks):
+            logging.info(f"Processing chunk: {index}")
             self.add_chunk_to_dataset(
                 chunks, chunk, self.num_questions_per_chunk, self.num_distract_docs
             )
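Put together, the pack can be exercised roughly as follows. This is a sketch: the import path follows the usual llama-index packs convention, and constructor arguments other than `file_path` are left at their defaults:

from llama_index.packs.raft_dataset import RAFTDatasetPack  # assumed import path

pack = RAFTDatasetPack(file_path="paper.pdf")
dataset = pack.run()
# With the new logging, the console shows INFO lines such as
# "Number of chunks created: 12" and "Processing chunk: 0" while it runs.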
@@ -29,7 +29,7 @@ license = "MIT"
 maintainers = ["ravi-theja"]
 name = "llama-index-packs-raft-dataset"
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
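Once this version is published, the change can be picked up with a normal upgrade, e.g. `pip install -U llama-index-packs-raft-dataset`.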