diff --git a/LLM/language_model.py b/LLM/language_model.py
index 5369c7350c52f9240fe064ffac6316db0196b5ec..372d21d915392801cd769798a6e3b265978134e9 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -75,7 +75,7 @@ class LanguageModelHandler(BaseHandler):
         dummy_input_text = "Write me a poem about Machine Learning."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
         warmup_gen_kwargs = {
-            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+            "min_new_tokens": self.gen_kwargs["min_new_tokens"],
             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
             **self.gen_kwargs,
         }
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 4f13a95eb57fa12f31fcb117e76512c7cb728742..fc87bd3a31c26ab988de658acdb64e4308c77cbf 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -71,7 +71,7 @@ class WhisperSTTHandler(BaseHandler):
             # generating more tokens than previously will trigger CUDA graphs capture
             # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
             warmup_gen_kwargs = {
-                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+                "min_new_tokens": self.gen_kwargs["min_new_tokens"],
                 "max_new_tokens": self.gen_kwargs["max_new_tokens"],
                 **self.gen_kwargs,
             }
diff --git a/arguments_classes/language_model_arguments.py b/arguments_classes/language_model_arguments.py
index cd66ca3795079a2bd033b22d5355b09c6883530e..8680a78fe07ec2d7cb5f6e1a3dcbb0ed3319ac6b 100644
--- a/arguments_classes/language_model_arguments.py
+++ b/arguments_classes/language_model_arguments.py
@@ -45,6 +45,12 @@ class LanguageModelHandlerArguments:
             "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
         },
     )
+    lm_gen_min_new_tokens: int = field(
+        default=0,
+        metadata={
+            "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
+        },
+    )
     lm_gen_temperature: float = field(
         default=0.0,
         metadata={
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
index bed382dda754da36965b4d86e68a7f8b4d9c322c..2edb4c24e7d75bd8a204edf1a82a8dd79b0df457 100644
--- a/arguments_classes/whisper_stt_arguments.py
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -33,6 +33,12 @@ class WhisperSTTHandlerArguments:
             "help": "The maximum number of new tokens to generate. Default is 128."
         },
     )
+    stt_gen_min_new_tokens: int = field(
+        default=0,
+        metadata={
+            "help": "The minimum number of new tokens to generate. Default is 0."
+        },
+    )
     stt_gen_num_beams: int = field(
         default=1,
         metadata={
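
For context, a minimal sketch (not part of the patch) of how the new --lm_gen_min_new_tokens / --stt_gen_min_new_tokens flags could reach a handler's gen_kwargs. It assumes the dataclasses are parsed with transformers.HfArgumentParser, and the build_gen_kwargs helper below is hypothetical, shown only to illustrate the lm_gen_/stt_gen_ prefix convention the new fields follow.

# Sketch only: build_gen_kwargs is a hypothetical helper, not the repo's actual wiring.
from dataclasses import asdict

from transformers import HfArgumentParser

from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments


def build_gen_kwargs(args, prefix):
    """Strip a field prefix, e.g. lm_gen_min_new_tokens=8 -> {"min_new_tokens": 8}."""
    return {
        name[len(prefix):]: value
        for name, value in asdict(args).items()
        if name.startswith(prefix)
    }


if __name__ == "__main__":
    parser = HfArgumentParser((LanguageModelHandlerArguments, WhisperSTTHandlerArguments))
    lm_args, stt_args = parser.parse_args_into_dataclasses()

    lm_gen_kwargs = build_gen_kwargs(lm_args, "lm_gen_")
    stt_gen_kwargs = build_gen_kwargs(stt_args, "stt_gen_")

    # e.g. python sketch.py --lm_gen_min_new_tokens 8 --stt_gen_min_new_tokens 4
    print("LM min_new_tokens:", lm_gen_kwargs["min_new_tokens"])    # 8
    print("STT min_new_tokens:", stt_gen_kwargs["min_new_tokens"])  # 4

The handlers presumably assemble their own gen_kwargs during setup; the sketch only shows that, with the defaults above, min_new_tokens resolves to 0 unless overridden on the command line, so the fixed warmup lines no longer force min_new_tokens up to max_new_tokens.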