From b7c229fb7d6ae51daca3914a6fc9c3b347758e11 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Mon, 26 Aug 2024 16:00:36 +0200
Subject: [PATCH] Add min_new_tokens generation argument to LM and Whisper STT handlers

---
 LLM/language_model.py                         | 2 +-
 STT/whisper_stt_handler.py                    | 2 +-
 arguments_classes/language_model_arguments.py | 6 ++++++
 arguments_classes/whisper_stt_arguments.py    | 6 ++++++
 4 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/LLM/language_model.py b/LLM/language_model.py
index 5369c73..372d21d 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -75,7 +75,7 @@ class LanguageModelHandler(BaseHandler):
         dummy_input_text = "Write me a poem about Machine Learning."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
         warmup_gen_kwargs = {
-            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+            "min_new_tokens": self.gen_kwargs["min_new_tokens"],
             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
             **self.gen_kwargs,
         }
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 4f13a95..fc87bd3 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -71,7 +71,7 @@ class WhisperSTTHandler(BaseHandler):
             # generating more tokens than previously will trigger CUDA graphs capture
             # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
             warmup_gen_kwargs = {
-                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+                "min_new_tokens": self.gen_kwargs["min_new_tokens"],
                 "max_new_tokens": self.gen_kwargs["max_new_tokens"],
                 **self.gen_kwargs,
             }
diff --git a/arguments_classes/language_model_arguments.py b/arguments_classes/language_model_arguments.py
index cd66ca3..8680a78 100644
--- a/arguments_classes/language_model_arguments.py
+++ b/arguments_classes/language_model_arguments.py
@@ -45,6 +45,12 @@ class LanguageModelHandlerArguments:
             "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
         },
     )
+    lm_gen_min_new_tokens: int = field(
+        default=0,
+        metadata={
+            "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
+        },
+    )
     lm_gen_temperature: float = field(
         default=0.0,
         metadata={
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
index bed382d..2edb4c2 100644
--- a/arguments_classes/whisper_stt_arguments.py
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -33,6 +33,12 @@ class WhisperSTTHandlerArguments:
             "help": "The maximum number of new tokens to generate. Default is 128."
         },
     )
+    stt_gen_min_new_tokens: int = field(
+        default=0,
+        metadata={
+            "help": "The minimum number of new tokens to generate. Default is 0."
+        },
+    )
     stt_gen_num_beams: int = field(
         default=1,
         metadata={
-- 
GitLab