From baa3e82e56a647d0281135c8c279fa1c386e8f6c Mon Sep 17 00:00:00 2001
From: Tuomas Tikkanen <tuomas.tikkanen@hotmail.com>
Date: Mon, 20 May 2024 23:04:52 +0300
Subject: [PATCH] Fix: Limit the number of generated questions (#13596)

---
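Notes: each generator touched below now truncates its cleaned question list to the
requested num_questions_per_chunk and emits a warning when the LLM returned fewer
questions than asked for. A minimal standalone sketch of that shared pattern in plain
Python follows; the helper name is hypothetical and does not exist in the codebase:

    import warnings

    def truncate_questions(questions, num_questions_per_chunk):
        # Hypothetical standalone helper: keep at most num_questions_per_chunk
        # non-empty questions, mirroring the slicing added in each hunk below.
        questions = [q for q in questions if len(q) > 0][:num_questions_per_chunk]
        if len(questions) < num_questions_per_chunk:
            # Under-production is reported as a warning rather than an error,
            # matching the behaviour this patch introduces.
            warnings.warn(
                f"Fewer questions generated ({len(questions)}) "
                f"than requested ({num_questions_per_chunk})."
            )
        return questions

For example, truncate_questions(["Q1?", "", "Q2?"], 3) returns the two non-empty
questions and emits the warning, since only two of the three requested questions
were produced.
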
 .../llama_index/core/llama_dataset/generator.py   | 12 +++++++++++-
 .../core/llama_dataset/legacy/embedding.py        | 12 +++++++++++-
 .../finetuning/cross_encoders/dataset_gen.py      | 10 ++++++++++
 .../llama_index/finetuning/embeddings/common.py   | 12 +++++++++++-
 llama-index-finetuning/pyproject.toml             |  2 +-
 .../llama_index/packs/raft_dataset/base.py        | 15 +++++++++++++--
 .../llama-index-packs-raft-dataset/pyproject.toml |  2 +-
 7 files changed, 58 insertions(+), 7 deletions(-)

diff --git a/llama-index-core/llama_index/core/llama_dataset/generator.py b/llama-index-core/llama_index/core/llama_dataset/generator.py
index 065653485f..808ed7eb00 100644
--- a/llama-index-core/llama_index/core/llama_dataset/generator.py
+++ b/llama-index-core/llama_index/core/llama_dataset/generator.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import re
+import warnings
 from typing import List, Optional
 
 from llama_index.core import Document, ServiceContext, SummaryIndex
@@ -78,6 +79,7 @@ class RagDatasetGenerator(PromptMixin):
     ) -> None:
         """Init params."""
         self._llm = llm or llm_from_settings_or_context(Settings, service_context)
+        self.num_questions_per_chunk = num_questions_per_chunk
         self.text_question_template = text_question_template or PromptTemplate(
             DEFAULT_QUESTION_GENERATION_PROMPT
         )
@@ -184,7 +186,15 @@ class RagDatasetGenerator(PromptMixin):
             ]
             cleaned_questions = [
                 question for question in cleaned_questions if len(question) > 0
-            ]
+            ][: self.num_questions_per_chunk]
+
+            num_questions_generated = len(cleaned_questions)
+            if num_questions_generated < self.num_questions_per_chunk:
+                warnings.warn(
+                    f"Fewer questions generated ({num_questions_generated}) "
+                    f"than requested ({self.num_questions_per_chunk})."
+                )
+
             index = summary_indices[idx]
             reference_context = nodes[idx].text
             model_name = self._llm.metadata.model_name
diff --git a/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py b/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py
index 41ad4e61f2..ad3922471a 100644
--- a/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py
+++ b/llama-index-core/llama_index/core/llama_dataset/legacy/embedding.py
@@ -2,6 +2,7 @@
 import json
 import re
 import uuid
+import warnings
 from typing import Dict, List, Tuple
 
 from llama_index.core.bridge.pydantic import BaseModel
@@ -89,7 +90,16 @@ def generate_qa_embedding_pairs(
         questions = [
             re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
         ]
-        questions = [question for question in questions if len(question) > 0]
+        questions = [question for question in questions if len(question) > 0][
+            :num_questions_per_chunk
+        ]
+
+        num_questions_generated = len(questions)
+        if num_questions_generated < num_questions_per_chunk:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({num_questions_per_chunk})."
+            )
 
         for question in questions:
             question_id = str(uuid.uuid4())
diff --git a/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py b/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py
index e4ad7d1d83..3c5205c82f 100644
--- a/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py
+++ b/llama-index-finetuning/llama_index/finetuning/cross_encoders/dataset_gen.py
@@ -1,5 +1,6 @@
 """Dataset Generator for Cross Encoder Finetuning."""
 import re
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional
 
@@ -72,6 +73,15 @@ def generate_synthetic_queries_over_documents(
             response.message.content if response.message.content is not None else ""
         )
         response_questions = re.split(";|\n", response_content)
+        response_questions = response_questions[:num_questions_per_chunk]
+
+        num_questions_generated = len(response_questions)
+        if num_questions_generated < num_questions_per_chunk:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({num_questions_per_chunk})."
+            )
+
         questions.extend(response_questions)
 
     return questions
diff --git a/llama-index-finetuning/llama_index/finetuning/embeddings/common.py b/llama-index-finetuning/llama_index/finetuning/embeddings/common.py
index 03bad69166..381dd1fd71 100644
--- a/llama-index-finetuning/llama_index/finetuning/embeddings/common.py
+++ b/llama-index-finetuning/llama_index/finetuning/embeddings/common.py
@@ -2,6 +2,7 @@
 import json
 import re
 import uuid
+import warnings
 from typing import Dict, List, Tuple
 
 from llama_index.core.bridge.pydantic import BaseModel
@@ -89,7 +90,16 @@ def generate_qa_embedding_pairs(
         questions = [
             re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
         ]
-        questions = [question for question in questions if len(question) > 0]
+        questions = [question for question in questions if len(question) > 0][
+            :num_questions_per_chunk
+        ]
+
+        num_questions_generated = len(questions)
+        if num_questions_generated < num_questions_per_chunk:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({num_questions_per_chunk})."
+            )
 
         for question in questions:
             question_id = str(uuid.uuid4())
diff --git a/llama-index-finetuning/pyproject.toml b/llama-index-finetuning/pyproject.toml
index e2de1fd134..c6d2087173 100644
--- a/llama-index-finetuning/pyproject.toml
+++ b/llama-index-finetuning/pyproject.toml
@@ -25,7 +25,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-finetuning"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py b/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py
index eef98a8dbc..007c2c2d59 100644
--- a/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py
+++ b/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py
@@ -5,6 +5,8 @@
 from typing import Any, List
 import random
 import logging
+import warnings
+
 from datasets import Dataset
 
 # Configure logging to output to the console, with messages of level DEBUG and above
@@ -108,8 +110,17 @@ class RAFTDatasetPack(BaseLlamaPack):
         ]
 
         queries = str(self.llm.chat(messages)).split("\n")
-        queries = [self.strip_str(q) for q in queries]
-        return [q for q in queries if any(c.isalpha() for c in q)]
+        questions = [self.strip_str(q) for q in queries]
+        questions = [q for q in questions if any(c.isalpha() for c in q)][:x]
+
+        num_questions_generated = len(questions)
+        if num_questions_generated < x:
+            warnings.warn(
+                f"Fewer questions generated ({num_questions_generated}) "
+                f"than requested ({x})."
+            )
+
+        return questions
 
     def get_chunks(self, file_path: str, chunk_size: int) -> List[str]:
         """
diff --git a/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml b/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml
index b559edd63f..d1a0e2823f 100644
--- a/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml
+++ b/llama-index-packs/llama-index-packs-raft-dataset/pyproject.toml
@@ -29,7 +29,7 @@ license = "MIT"
 maintainers = ["ravi-theja"]
 name = "llama-index-packs-raft-dataset"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-- 
GitLab
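
Post-signature note: a caller-side sketch (standard library only, illustrative
values) showing how the new warning can be escalated to an error if silent
truncation is unacceptable:

    import warnings

    with warnings.catch_warnings():
        # Match the message emitted by the warnings.warn(...) calls added above.
        warnings.filterwarnings("error", message="Fewer questions generated")
        try:
            # Stand-in for a dataset-generation call that under-produces questions.
            warnings.warn("Fewer questions generated (1) than requested (3).")
        except UserWarning as exc:
            print(f"Generation under-produced: {exc}")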