diff --git a/llama-index-core/llama_index/core/llama_dataset/generator.py b/llama-index-core/llama_index/core/llama_dataset/generator.py
index 065653485f58ddf2c15b1fc16e8a3e9639063380..808ed7eb00a4f019e6d7d9b9c30c6388cd50f237 100644
--- a/llama-index-core/llama_index/core/llama_dataset/generator.py
+++ b/llama-index-core/llama_index/core/llama_dataset/generator.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import re
+import warnings
 from typing import List, Optional
 
 from llama_index.core import Document, ServiceContext, SummaryIndex
@@ -78,6 +79,7 @@ class RagDatasetGenerator(PromptMixin):
     ) -> None:
         """Init params."""
         self._llm = llm or llm_from_settings_or_context(Settings, service_context)
+        self.num_questions_per_chunk = num_questions_per_chunk
         self.text_question_template = text_question_template or PromptTemplate(
             DEFAULT_QUESTION_GENERATION_PROMPT
         )
@@ -184,7 +186,15 @@ class RagDatasetGenerator(PromptMixin):
             ]
             cleaned_questions = [
                 question for question in cleaned_questions if len(question) > 0
-            ]
+            ][: self.num_questions_per_chunk]
+
+            num_questions_generated = len(cleaned_questions)
+            if num_questions_generated < self.num_questions_per_chunk:
+                warnings.warn(
+                    f"Fewer questions generated ({num_questions_generated}) "
+                    f"than requested ({self.num_questions_per_chunk})."
+                )
+
             index = summary_indices[idx]
             reference_context = nodes[idx].text
             model_name = self._llm.metadata.model_name
diff --git a/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py b/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py
index 41ad4e61f24708b29bf9ad0230213f29af76c0ef..ad3922471a4047470067bb938d6e73c2b39724e1 100644
--- a/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py
+++ b/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py
@@ -2,6 +2,7 @@
 import json
 import re
 import uuid
+import warnings
 from typing import Dict, List, Tuple
 
 from llama_index.core.bridge.pydantic import BaseModel
@@ -89,7 +90,16 @@ def generate_qa_embedding_pairs(
         questions = [
             re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
         ]
-        questions = [question for question in questions if len(question) > 0]
+        questions = [question for question in questions if len(question) > 0][
+            :num_questions_per_chunk
+        ]
+
+        num_questions_generated = len(questions)
+        if num_questions_generated < num_questions_per_chunk:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({num_questions_per_chunk})."
+            )
 
         for question in questions:
             question_id = str(uuid.uuid4())
diff --git a/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py b/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py
index e4ad7d1d830aca396e4bf09806130b654f5a25aa..3c5205c82f0811bc7f2ccc69ce1422fd81bd441b 100644
--- a/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py
+++ b/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py
@@ -1,5 +1,6 @@
 """Dataset Generator for Cross Encoder Finetuning."""
 import re
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional
 
@@ -72,6 +73,15 @@ def generate_synthetic_queries_over_documents(
         response_content: str = (
             response.message.content if response.message.content is not None else ""
         )
         response_questions = re.split(";|\n", response_content)
+        response_questions = response_questions[:num_questions_per_chunk]
+
+        num_questions_generated = len(response_questions)
+        if num_questions_generated < num_questions_per_chunk:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({num_questions_per_chunk})."
+            )
+
         questions.extend(response_questions)
 
     return questions
diff --git a/llama-index-finetuning/llama_index/finetuning/embeddings/common.py b/llama-index-finetuning/llama_index/finetuning/embeddings/common.py
index 03bad691661d3b11c3ade58fb37cdfc2a610f073..381dd1fd7152a651030382142593f43573e55738 100644
--- a/llama-index-finetuning/llama_index/finetuning/embeddings/common.py
+++ b/llama-index-finetuning/llama_index/finetuning/embeddings/common.py
@@ -2,6 +2,7 @@
 import json
 import re
 import uuid
+import warnings
 from typing import Dict, List, Tuple
 
 from llama_index.core.bridge.pydantic import BaseModel
@@ -89,7 +90,16 @@ def generate_qa_embedding_pairs(
         questions = [
             re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
         ]
-        questions = [question for question in questions if len(question) > 0]
+        questions = [question for question in questions if len(question) > 0][
+            :num_questions_per_chunk
+        ]
+
+        num_questions_generated = len(questions)
+        if num_questions_generated < num_questions_per_chunk:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({num_questions_per_chunk})."
+            )
 
         for question in questions:
             question_id = str(uuid.uuid4())
diff --git a/llama-index-finetuning/pyproject.toml b/llama-index-finetuning/pyproject.toml
index e2de1fd134ad80697fb350d17bed4b760a3356ca..c6d2087173548930e914a508177b511388f1eea3 100644
--- a/llama-index-finetuning/pyproject.toml
+++ b/llama-index-finetuning/pyproject.toml
@@ -25,7 +25,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-finetuning"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py b/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py
index eef98a8dbca23507891eb82e7e0915de7f34f3a7..007c2c2d59db47e0b162cc1f7acf9cee388f31a4 100644
--- a/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py
+++ b/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py
@@ -5,6 +5,8 @@
 from typing import Any, List
 import random
 import logging
+import warnings
+
 from datasets import Dataset
 
 # Configure logging to output to the console, with messages of level DEBUG and above
@@ -108,8 +110,17 @@ class RAFTDatasetPack(BaseLlamaPack):
         ]
 
         queries = str(self.llm.chat(messages)).split("\n")
-        queries = [self.strip_str(q) for q in queries]
-        return [q for q in queries if any(c.isalpha() for c in q)]
+        questions = [self.strip_str(q) for q in queries]
+        questions = [q for q in questions if any(c.isalpha() for c in q)][:x]
+
+        num_questions_generated = len(questions)
+        if num_questions_generated < x:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({x})."
+            )
+
+        return questions
 
     def get_chunks(self, file_path: str, chunk_size: int) -> List[str]:
         """
diff --git a/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml b/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml
index b559edd63f34fdcec9516caf089e9a9dd7ce0b72..d1a0e2823fb836732f5db74b8d7fe68e980f856d 100644
--- a/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml
+++ b/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml
@@ -29,7 +29,7 @@ license = "MIT"
 maintainers = ["ravi-theja"]
 name = "llama-index-packs-raft-dataset"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
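
For illustration, a minimal standalone sketch of the truncate-and-warn pattern each of these patches applies: cap the cleaned question list at the requested count, then warn if the LLM produced fewer non-empty questions than asked for. The helper name cap_questions is hypothetical and not part of the library; this is a sketch of the pattern, not the patched code itself.

import warnings
from typing import List


def cap_questions(questions: List[str], num_questions_per_chunk: int) -> List[str]:
    """Keep at most num_questions_per_chunk non-empty questions, warning on a shortfall."""
    # Drop empty strings, then truncate to the requested count.
    questions = [q for q in questions if len(q) > 0][:num_questions_per_chunk]
    if len(questions) < num_questions_per_chunk:
        warnings.warn(
            f"Fewer questions generated ({len(questions)}) "
            f"than requested ({num_questions_per_chunk})."
        )
    return questions


# Example: the LLM returned only two non-empty questions when three were requested.
print(cap_questions(["Q1?", "Q2?", ""], num_questions_per_chunk=3))
# -> ["Q1?", "Q2?"], with a UserWarning: Fewer questions generated (2) than requested (3).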