From 0574f84029d4f0d6fa131bb5899100098567db33 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 16:06:55 +0000
Subject: [PATCH] compile parler tts

---
 s2s_pipeline.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 5c81b59..eebb2c6 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -301,7 +301,7 @@ class VADHandler(BaseHandler):
         audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
         audio_float32 = int2float(audio_int16)
         vad_output = self.iterator(torch.from_numpy(audio_float32))
-        if vad_output is not None:
+        if vad_output is not None and len(vad_output) != 0:
             logger.debug("VAD: end of speech detected")
             array = torch.cat(vad_output).cpu().numpy()
             duration_ms = len(array) / self._sample_rate * 1000
@@ -622,11 +622,13 @@ class ParlerTTSHandlerArguments:
             "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
         }
     )
-    gen_kwargs: dict = field(
-        default_factory=dict,
-        metadata={
-            "help": "Additional keyword arguments to pass to the model's generate method. Use this to customize generation settings."
-        }
+    tts_gen_min_new_tokens: int = field(
+        default=10,
+        metadata={"help": "Minimum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"}
+    )
+    tts_gen_max_new_tokens: int = field(
+        default=512,
+        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 512, which corresponds to ~6 secs"}
     )
     description: str = field(
         default=(
@@ -643,6 +645,12 @@ class ParlerTTSHandlerArguments:
             "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
         }
     )
+    max_prompt_pad_length: int = field(
+        default=8,
+        metadata={
+            "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 possible."
+        }
+    )
 
 
 class ParlerTTSHandler(BaseHandler):
@@ -652,6 +660,7 @@ class ParlerTTSHandler(BaseHandler):
         model_name="ylacombe/parler-tts-mini-jenny-30H",
         device="cuda",
         torch_dtype="float16",
+        max_prompt_pad_length=8,
         gen_kwargs={},
         compile_mode=None,
         description=(
@@ -660,6 +669,7 @@ class ParlerTTSHandler(BaseHandler):
         ),
         play_steps_s=1
     ):
+        self.max_prompt_pad_length = max_prompt_pad_length
         torch_dtype = getattr(torch, torch_dtype)
         self._should_listen = should_listen
         self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -681,6 +691,10 @@ class ParlerTTSHandler(BaseHandler):
         self.play_steps = int(framerate * play_steps_s)
         self.compile_mode = compile_mode
 
+        if self.compile_mode not in (None, "default"):
+            logger.warning("Torch compilation modes that capture CUDA graphs are not yet compatible with the STT part. Reverting to 'default'")
+            self.compile_mode = "default"
+
         if self.compile_mode:
             self.model.generation_config.cache_implementation = "static"
             self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
@@ -712,7 +726,7 @@ class ParlerTTSHandler(BaseHandler):
         return gen_kwargs
 
     def warmup(self):
-        pad_lengths = [2**i for i in range(4, 9)]
+        pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
        for pad_length in pad_lengths[::-1]:
            model_kwargs = self.prepare_model_inputs(
                "dummy prompt",
@@ -722,15 +736,6 @@ class ParlerTTSHandler(BaseHandler):
 
             # 2 warmup steps for modes that capture CUDA graphs
             n_steps = 1 if self.compile_mode == "default" else 2
-            if self.compile_mode not in (None, "default"):
-                # generating more tokens than previously will trigger CUDA graphs capture
-                # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
-                model_kwargs = {
-                    "min_new_tokens": 86*3,
-                    "max_new_tokens": 86*3,
-                    **model_kwargs
-                }
-
             logger.info(f"Warming up length {pad_length} tokens...")
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
-- 
GitLab
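
Note on the prompt-padding scheme: prepare_model_inputs is not shown in this diff, but the warmup loop implies it pads each prompt to one of the bucket lengths 2**2 ... 2**(max_prompt_pad_length - 1), so the compiled forward only ever sees a small, fixed set of input shapes. A minimal sketch of that bucketing, assuming a plain list of token ids and a pad_token_id — the helper name and signature are illustrative, not code from this patch:

    import math

    def pad_to_bucket(input_ids, pad_token_id, max_prompt_pad_length=8):
        # Mirror the warmup buckets: 2**2, 2**3, ..., 2**(max_prompt_pad_length - 1).
        # Reusing one of these fixed lengths at inference time lets torch.compile
        # reuse the graphs captured during warmup instead of recompiling.
        power = max(math.ceil(math.log2(len(input_ids))), 2)
        power = min(power, max_prompt_pad_length - 1)
        target = 2 ** power
        # Prompts longer than the largest bucket are left unpadded in this sketch.
        return input_ids + [pad_token_id] * max(target - len(input_ids), 0)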
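
The token defaults follow from Parler-TTS decoding roughly 86 audio tokens per second (the same constant the removed 86*3 warmup block relied on): 10 tokens is about 0.1 s and 512 tokens about 6 s of speech. Assuming the pipeline's argument parser exposes the dataclass fields under these names (the exact flags depend on how ParlerTTSHandlerArguments is mapped to the CLI), a run exercising the new options might look like:

    python s2s_pipeline.py \
        --compile_mode default \
        --max_prompt_pad_length 8 \
        --tts_gen_min_new_tokens 10 \
        --tts_gen_max_new_tokens 512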