Skip to content
Snippets Groups Projects
Commit b7c229fb authored by Andres Marafioti
Browse files

add min new tokens

parent d3d25c45
No related branches found
No related tags found
No related merge requests found
...@@ -75,7 +75,7 @@ class LanguageModelHandler(BaseHandler): ...@@ -75,7 +75,7 @@ class LanguageModelHandler(BaseHandler):
dummy_input_text = "Write me a poem about Machine Learning." dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
warmup_gen_kwargs = { warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"], "min_new_tokens": self.gen_kwargs["min_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs, **self.gen_kwargs,
} }
......
...@@ -71,7 +71,7 @@ class WhisperSTTHandler(BaseHandler): ...@@ -71,7 +71,7 @@ class WhisperSTTHandler(BaseHandler):
# generating more tokens than previously will trigger CUDA graphs capture # generating more tokens than previously will trigger CUDA graphs capture
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
warmup_gen_kwargs = { warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"], "min_new_tokens": self.gen_kwargs["min_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"], "max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs, **self.gen_kwargs,
} }
......
...@@ -45,6 +45,12 @@ class LanguageModelHandlerArguments: ...@@ -45,6 +45,12 @@ class LanguageModelHandlerArguments:
"help": "Maximum number of new tokens to generate in a single completion. Default is 128." "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
}, },
) )
lm_gen_min_new_tokens: int = field(
default=0,
metadata={
"help": "Minimum number of new tokens to generate in a single completion. Default is 0."
},
)
lm_gen_temperature: float = field( lm_gen_temperature: float = field(
default=0.0, default=0.0,
metadata={ metadata={
......
...@@ -33,6 +33,12 @@ class WhisperSTTHandlerArguments: ...@@ -33,6 +33,12 @@ class WhisperSTTHandlerArguments:
"help": "The maximum number of new tokens to generate. Default is 128." "help": "The maximum number of new tokens to generate. Default is 128."
}, },
) )
stt_gen_min_new_tokens: int = field(
default=0,
metadata={
"help": "The minimum number of new tokens to generate. Default is 0."
},
)
stt_gen_num_beams: int = field( stt_gen_num_beams: int = field(
default=1, default=1,
metadata={ metadata={
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment