From bfc61a98bcf13b74b252f727174c837533e01d3d Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 13:01:04 +0000
Subject: [PATCH] warmup stt

---
 s2s_pipeline.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 1fd0c48..76947e4 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -409,7 +409,7 @@ class WhisperSTTHandler(BaseHandler):
     def warmup(self):
         # 2 warmup steps for no compile or compile mode with CUDA graphs capture
         n_steps = 1 if self.compile_mode == "default" else 2
-        logger.debug(f"Warming up {self.__class__.__name__}")
+        logger.info(f"Warming up {self.__class__.__name__}")
         dummy_input = torch.randn(
             (1, self.model.config.num_mel_bins, 3000),
             dtype=self.torch_dtype,
@@ -418,12 +418,23 @@ class WhisperSTTHandler(BaseHandler):
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         torch.cuda.synchronize()
+        if self.compile_mode:
+            # Generating more tokens than during warmup re-triggers CUDA graphs capture,
+            # so warm up with the max number of tokens targeted for subsequent generations.
+            warmup_gen_kwargs = {
+                **self.gen_kwargs,
+                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+            }
+        else:
+            warmup_gen_kwargs = self.gen_kwargs
+
         start_event.record()
         for _ in range(n_steps):
-            _ = self.model.generate(dummy_input, **self.gen_kwargs)
+            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
         end_event.record()
         torch.cuda.synchronize()
-        logger.debug(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
 
     def process(self, spoken_prompt):
         global pipeline_start
-- 
GitLab
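
Pinning min_new_tokens to max_new_tokens in the hunk above matters because, under torch.compile's CUDA-graphs modes, the graphs are captured for the decoding step counts seen during the first calls; a later generation that runs for more steps than the warmup did forces a fresh, slow capture. The following is a minimal standalone sketch of the same warmup strategy, not part of the patch; the checkpoint name, token budget, and static-cache setup are illustrative assumptions.

# Standalone sketch (assumes a CUDA device and a transformers version with
# Whisper + static-cache support; checkpoint and token budget are illustrative).
import torch
from transformers import AutoModelForSpeechSeq2Seq

torch_dtype = torch.float16
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "distil-whisper/distil-large-v3", torch_dtype=torch_dtype
).to("cuda")
# A static KV cache plus "reduce-overhead" enables CUDA graphs for the decode loop.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

gen_kwargs = {"max_new_tokens": 128}
# Force the warmup to decode exactly max_new_tokens steps, so the graphs captured
# here already cover the longest generation that will be requested later.
warmup_gen_kwargs = {
    **gen_kwargs,
    "min_new_tokens": gen_kwargs["max_new_tokens"],
    "max_new_tokens": gen_kwargs["max_new_tokens"],
}

# Whisper expects log-mel features of shape (batch, num_mel_bins, 3000).
dummy_input = torch.randn(
    (1, model.config.num_mel_bins, 3000), dtype=torch_dtype, device="cuda"
)
for _ in range(2):  # first pass captures the graphs, second replays them
    _ = model.generate(dummy_input, **warmup_gen_kwargs)
torch.cuda.synchronize()

Note that min_new_tokens is what makes the dummy input workable: random mel features would otherwise tend to produce an early end-of-sequence token, cutting the warmup short of the step count the graphs need to cover.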