diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 1fd0c48ec814a7af51d81cda8c554a593360ad84..76947e46211ff614dd054ebc238d660421cc0c2f 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -409,7 +409,7 @@ class WhisperSTTHandler(BaseHandler):
     def warmup(self):
         # 2 warmup steps for no compile or compile mode with CUDA graphs capture
         n_steps = 1 if self.compile_mode == "default" else 2
-        logger.debug(f"Warming up {self.__class__.__name__}")
+        logger.info(f"Warming up {self.__class__.__name__}")
         dummy_input = torch.randn(
             (1, self.model.config.num_mel_bins, 3000),
             dtype=self.torch_dtype,
@@ -418,12 +418,23 @@ class WhisperSTTHandler(BaseHandler):
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         torch.cuda.synchronize()
+        if self.compile_mode:
+            # Generating more new tokens than during warmup re-triggers CUDA graph capture,
+            # so warm up with at least as many tokens as later generate() calls will request.
+            warmup_gen_kwargs = {
+                **self.gen_kwargs,
+                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+            }
+        else:
+            warmup_gen_kwargs = self.gen_kwargs
+
         start_event.record()
         for _ in range(n_steps):
-            _ = self.model.generate(dummy_input, **self.gen_kwargs)
+            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
         end_event.record()
         torch.cuda.synchronize()
-        logger.debug(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
 
     def process(self, spoken_prompt):
         global pipeline_start
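
For context on why the patch pins the generated length: with `torch.compile` in a CUDA-graphs mode, graphs are captured for the sequence lengths actually seen, so a later `generate()` call that decodes more new tokens than any warmup run would stall while a fresh graph is captured. Below is a minimal, self-contained sketch of the same warmup pattern outside the pipeline; the checkpoint id, token budget, and dtype are illustrative assumptions, not values taken from this repo's config.

```python
import torch
from transformers import WhisperForConditionalGeneration

# Illustrative values; the real pipeline reads these from its CLI arguments.
model_id = "openai/whisper-tiny"
max_new_tokens = 64  # upper bound used by all later generate() calls
device, dtype = "cuda", torch.float16

model = WhisperForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=dtype
).to(device)

# A static KV cache plus a compiled forward is what enables CUDA graph capture.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Random mel features with Whisper's fixed 30 s input shape suffice for warmup.
dummy_input = torch.randn(
    (1, model.config.num_mel_bins, 3000), dtype=dtype, device=device
)

# Pin the generated length to max_new_tokens so the graphs captured here cover
# the longest decode any later request can trigger. Two steps: capture, then replay.
for _ in range(2):
    _ = model.generate(
        dummy_input, min_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens
    )
```

The `else` branch in the patch matters for the same reason in reverse: when `compile_mode` is unset there are no graphs to capture, so warmup can simply reuse `self.gen_kwargs` unchanged.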