From b7c229fb7d6ae51daca3914a6fc9c3b347758e11 Mon Sep 17 00:00:00 2001 From: Andres Marafioti <andimarafioti@gmail.com> Date: Mon, 26 Aug 2024 16:00:36 +0200 Subject: [PATCH] add min new tokens --- LLM/language_model.py | 2 +- STT/whisper_stt_handler.py | 2 +- arguments_classes/language_model_arguments.py | 6 ++++++ arguments_classes/whisper_stt_arguments.py | 6 ++++++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/LLM/language_model.py b/LLM/language_model.py index 5369c73..372d21d 100644 --- a/LLM/language_model.py +++ b/LLM/language_model.py @@ -75,7 +75,7 @@ class LanguageModelHandler(BaseHandler): dummy_input_text = "Write me a poem about Machine Learning." dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] warmup_gen_kwargs = { - "min_new_tokens": self.gen_kwargs["max_new_tokens"], + "min_new_tokens": self.gen_kwargs["min_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"], **self.gen_kwargs, } diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index 4f13a95..fc87bd3 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -71,7 +71,7 @@ class WhisperSTTHandler(BaseHandler): # generating more tokens than previously will trigger CUDA graphs capture # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation warmup_gen_kwargs = { - "min_new_tokens": self.gen_kwargs["max_new_tokens"], + "min_new_tokens": self.gen_kwargs["min_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"], **self.gen_kwargs, } diff --git a/arguments_classes/language_model_arguments.py b/arguments_classes/language_model_arguments.py index cd66ca3..8680a78 100644 --- a/arguments_classes/language_model_arguments.py +++ b/arguments_classes/language_model_arguments.py @@ -45,6 +45,12 @@ class LanguageModelHandlerArguments: "help": "Maximum number of new tokens to generate in a single completion. Default is 128." }, ) + lm_gen_min_new_tokens: int = field( + default=0, + metadata={ + "help": "Minimum number of new tokens to generate in a single completion. Default is 0." + }, + ) lm_gen_temperature: float = field( default=0.0, metadata={ diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py index bed382d..2edb4c2 100644 --- a/arguments_classes/whisper_stt_arguments.py +++ b/arguments_classes/whisper_stt_arguments.py @@ -33,6 +33,12 @@ class WhisperSTTHandlerArguments: "help": "The maximum number of new tokens to generate. Default is 128." }, ) + stt_gen_min_new_tokens: int = field( + default=0, + metadata={ + "help": "The minimum number of new tokens to generate. Default is 0." + }, + ) stt_gen_num_beams: int = field( default=1, metadata={ -- GitLab