diff --git a/llama_index/llama_dataset/generator.py b/llama_index/llama_dataset/generator.py
index 2e3aeb84d64c7098994358f088acde20b843be9f..6adb38213d89062f3ebc250b2e48def91ae9699b 100644
--- a/llama_index/llama_dataset/generator.py
+++ b/llama_index/llama_dataset/generator.py
@@ -6,6 +6,7 @@ import re
 from typing import List
 
 from llama_index import Document, ServiceContext, SummaryIndex
+from llama_index.async_utils import DEFAULT_NUM_WORKERS, run_jobs
 from llama_index.ingestion import run_transformations
 from llama_index.llama_dataset import (
     CreatedBy,
@@ -57,6 +58,7 @@ class RagDatasetGenerator(PromptMixin):
         question_gen_query: str | None = None,
         metadata_mode: MetadataMode = MetadataMode.NONE,
         show_progress: bool = False,
+        workers: int = DEFAULT_NUM_WORKERS,
     ) -> None:
         """Init params."""
         if service_context is None:
@@ -79,6 +81,7 @@ class RagDatasetGenerator(PromptMixin):
         self.nodes = nodes
         self._metadata_mode = metadata_mode
         self._show_progress = show_progress
+        self._workers = workers
 
     @classmethod
     def from_documents(
@@ -92,6 +95,7 @@ class RagDatasetGenerator(PromptMixin):
         required_keywords: List[str] | None = None,
         exclude_keywords: List[str] | None = None,
         show_progress: bool = False,
+        workers: int = DEFAULT_NUM_WORKERS,
     ) -> RagDatasetGenerator:
         """Generate dataset from documents."""
         if service_context is None:
@@ -123,6 +127,7 @@ class RagDatasetGenerator(PromptMixin):
             text_qa_template=text_qa_template,
             question_gen_query=question_gen_query,
             show_progress=show_progress,
+            workers=workers,
         )
 
     async def _agenerate_dataset(
@@ -133,14 +138,6 @@ class RagDatasetGenerator(PromptMixin):
         """Node question generator."""
         query_tasks = []
         examples: List[LabelledRagDataExample] = []
-
-        if self._show_progress:
-            from tqdm.asyncio import tqdm_asyncio
-
-            async_module = tqdm_asyncio
-        else:
-            async_module = asyncio
-
         summary_indices: List[SummaryIndex] = []
         for node in nodes:
             index = SummaryIndex.from_documents(
@@ -164,7 +161,7 @@ class RagDatasetGenerator(PromptMixin):
             query_tasks.append(task)
             summary_indices.append(index)
 
-        responses = await async_module.gather(*query_tasks)  # result order is preserved
+        responses = await run_jobs(query_tasks, self._show_progress, self._workers)
         for idx, response in enumerate(responses):
             result = str(response).strip().split("\n")
             cleaned_questions = [
@@ -187,9 +184,9 @@ class RagDatasetGenerator(PromptMixin):
                    )
                    qr_task = qa_query_engine.aquery(query)
                    qr_tasks.append(qr_task)
-                answer_responses: List[RESPONSE_TYPE] = await async_module.gather(
-                    *qr_tasks
-                )  # execution order is not guaranteed but result values order is preserved
+                answer_responses: List[RESPONSE_TYPE] = await run_jobs(
+                    qr_tasks, self._show_progress, self._workers
+                )
                 for question, answer_response in zip(
                     cleaned_questions, answer_responses
                 ):
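
For reviewers, here is a minimal usage sketch (not part of the patch) exercising the new `workers` knob: it assumes a caller-built document set and `ServiceContext`, and that `generate_dataset_from_nodes()` remains the public entry point driving `_agenerate_dataset`.

```python
# Minimal usage sketch, assuming caller-provided data and service context.
from llama_index import ServiceContext, SimpleDirectoryReader
from llama_index.llama_dataset.generator import RagDatasetGenerator

documents = SimpleDirectoryReader("./data").load_data()  # assumed data directory
service_context = ServiceContext.from_defaults()

dataset_generator = RagDatasetGenerator.from_documents(
    documents,
    service_context=service_context,
    show_progress=True,  # progress reporting now flows through run_jobs
    workers=8,           # new: caps concurrent async jobs (defaults to DEFAULT_NUM_WORKERS)
)
# Assumed existing public wrapper around _agenerate_dataset.
rag_dataset = dataset_generator.generate_dataset_from_nodes()
```

With `run_jobs`, both question generation and answer generation are throttled by a semaphore of `workers` tasks rather than gathering every coroutine at once, while `show_progress` still drives the tqdm progress bar.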