Commit bfc61a98 authored by Eustache Le Bihan

warmup stt

parent 40addb65
@@ -409,7 +409,7 @@ class WhisperSTTHandler(BaseHandler):
     def warmup(self):
         # 2 warmup steps for no compile or compile mode with CUDA graphs capture
         n_steps = 1 if self.compile_mode == "default" else 2
-        logger.debug(f"Warming up {self.__class__.__name__}")
+        logger.info(f"Warming up {self.__class__.__name__}")
         dummy_input = torch.randn(
             (1, self.model.config.num_mel_bins, 3000),
             dtype=self.torch_dtype,
@@ -418,12 +418,21 @@ class WhisperSTTHandler(BaseHandler):
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         torch.cuda.synchronize()
+        if self.compile_mode:
+            # generating more tokens than previously will trigger CUDA graphs capture
+            # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
+            warmup_gen_kwargs = {
+                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+                **self.gen_kwargs,
+            }
         start_event.record()
         for _ in range(n_steps):
-            _ = self.model.generate(dummy_input, **self.gen_kwargs)
+            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
         end_event.record()
         torch.cuda.synchronize()
-        logger.debug(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")

     def process(self, spoken_prompt):
         global pipeline_start
...
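For context on why the warmup pins min_new_tokens to the serving max: with torch.compile in "reduce-overhead" mode, a generation that produces more new tokens than any generation seen during warmup forces a fresh CUDA graphs capture at serving time, which stalls the pipeline. Below is a minimal standalone sketch of the same idea outside the handler class. The checkpoint name (openai/whisper-tiny), the max_new_tokens value of 128, and the compile settings are illustrative assumptions, not values taken from this commit.

# Minimal sketch of the warmup pattern, assuming a CUDA GPU and transformers installed.
import torch
from transformers import AutoModelForSpeechSeq2Seq

device = "cuda"
torch_dtype = torch.float16

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch_dtype  # hypothetical checkpoint choice
).to(device)

# "reduce-overhead" enables CUDA graphs capture of the compiled forward pass.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

gen_kwargs = {"max_new_tokens": 128}  # assumed serving-time budget

# Pin min_new_tokens to the serving max so the warmup reaches every decode step
# a real request can reach; a shorter warmup would leave longer generations to
# re-trigger graph capture later.
warmup_gen_kwargs = {**gen_kwargs, "min_new_tokens": gen_kwargs["max_new_tokens"]}

# Whisper expects a (batch, num_mel_bins, 3000) log-mel input (30 s of audio).
dummy_input = torch.randn(
    (1, model.config.num_mel_bins, 3000), dtype=torch_dtype, device=device
)

for _ in range(2):  # two steps: the first compiles and captures, the second replays
    _ = model.generate(dummy_input, **warmup_gen_kwargs)
torch.cuda.synchronize()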