From 9ed2fcf8abbda279c3600bb2deec21f07ed03065 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 14:54:04 +0000
Subject: [PATCH] compile parler tts

---
 s2s_pipeline.py | 107 +++++++++++++++++++++++++++++++++++++++-------
 utils.py        |  11 +----
 2 files changed, 91 insertions(+), 27 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 76947e4..274c472 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -33,6 +33,7 @@ from parler_tts import (
 from utils import (
     VADIterator,
     int2float,
+    next_power_of_2,
 )
 
 
@@ -41,7 +42,7 @@ from utils import (
 CURRENT_DIR = Path(__file__).resolve().parent
 os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
 torch._inductor.config.fx_graph_cache = True
-# mind about this parameter ! should be >= 2 * number of compiled models
+# mind this parameter! it should be >= 2 * the number of padded prompt sizes used for TTS
 torch._dynamo.config.cache_size_limit = 15
 
 console = Console()
@@ -418,7 +419,7 @@
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         torch.cuda.synchronize()
-        if self.compile_mode:
+        if self.compile_mode not in (None, "default"):
             # generating more tokens than previously will trigger CUDA graphs capture
             # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
             warmup_gen_kwargs = {
@@ -587,6 +588,12 @@ class ParlerTTSHandlerArguments:
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         }
     )
+    tts_compile_mode: str = field(
+        default=None,
+        metadata={
+            "help": "Compile mode for torch.compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Defaults to None (no compilation)."
+        }
+    )
     gen_kwargs: dict = field(
         default_factory=dict,
         metadata={
@@ -618,11 +625,12 @@ class ParlerTTSHandler(BaseHandler):
         device="cuda",
         torch_dtype="float16",
         gen_kwargs={},
+        compile_mode=None,
         description=(
             "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
             "She speaks very fast."
         ),
-        play_steps_s=0.5
+        play_steps_s=1
     ):
         torch_dtype = getattr(torch, torch_dtype)
         self._should_listen = should_listen
@@ -635,33 +643,96 @@ class ParlerTTSHandler(BaseHandler):
         self.device = device
         self.torch_dtype = torch_dtype
 
-        tokenized_description = self.description_tokenizer(description, return_tensors="pt")
-        input_ids = tokenized_description.input_ids.to(self.device)
-        attention_mask = tokenized_description.attention_mask.to(self.device)
-
-        self.gen_kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            **gen_kwargs
-        }
+        self.description = description
+        self.gen_kwargs = gen_kwargs
 
         framerate = self.model.audio_encoder.config.frame_rate
         self.play_steps = int(framerate * play_steps_s)
+
+        self.compile_mode = compile_mode
+        if self.compile_mode:
+            self.model.generation_config.cache_implementation = "static"
+            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+        self.warmup()
+
+    def prepare_model_inputs(
+        self,
+        prompt,
+        max_length_prompt=50,
+        pad=False,
+    ):
+        pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
 
-    def process(self, lm_sentence):
-        console.print(f"[green]ASSISTANT: {lm_sentence}")
-        tokenized_prompt = self.prompt_tokenizer(lm_sentence, return_tensors="pt")
+        tokenized_description = self.description_tokenizer(self.description, return_tensors="pt")
+        input_ids = tokenized_description.input_ids.to(self.device)
+        attention_mask = tokenized_description.attention_mask.to(self.device)
+
+        tokenized_prompt = self.prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
         prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
         prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
 
-        streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps)
-        tts_gen_kwargs = {
+        gen_kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
             "prompt_input_ids": prompt_input_ids,
             "prompt_attention_mask": prompt_attention_mask,
-            "streamer": streamer,
             **self.gen_kwargs
         }
+        return gen_kwargs
+
+    def warmup(self):
+        pad_lengths = [2**i for i in range(4, 9)]
+        for pad_length in pad_lengths[::-1]:
+            model_kwargs = self.prepare_model_inputs(
+                "dummy prompt",
+                max_length_prompt=pad_length,
+                pad=True
+            )
+            # 2 warmup steps for modes that capture CUDA graphs
+            n_steps = 1 if self.compile_mode == "default" else 2
+
+            if self.compile_mode not in (None, "default"):
+                # generating more tokens than previously will trigger CUDA graphs capture
+                # one should warm up with a number of generated tokens above the max tokens targeted for subsequent generation
+                model_kwargs = {
+                    "min_new_tokens": 86*3,
+                    "max_new_tokens": 86*3,
+                    **model_kwargs
+                }
+
+            logger.info(f"Warming up length {pad_length} tokens...")
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
+            for _ in range(n_steps):
+                _ = self.model.generate(**model_kwargs)
+            end_event.record()
+            torch.cuda.synchronize()
+            logger.info(f"Warmed up! Compilation time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+
+    def process(self, llm_sentence):
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
+
+        pad_args = {}
+        if self.compile_mode:
+            # pad to closest upper power of two
+            pad_length = next_power_of_2(nb_tokens)
+            logger.debug(f"padding to {pad_length}")
+            pad_args["pad"] = True
+            pad_args["max_length_prompt"] = pad_length
+
+        tts_gen_kwargs = self.prepare_model_inputs(
+            llm_sentence,
+            **pad_args,
+        )
+        streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps)
+        tts_gen_kwargs = {
+            "streamer": streamer,
+            **tts_gen_kwargs
+        }
 
         torch.manual_seed(0)
         thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
         thread.start()
diff --git a/utils.py b/utils.py
index 4a3c621..6ad672c 100644
--- a/utils.py
+++ b/utils.py
@@ -12,16 +12,9 @@ from time import perf_counter
 
 from parler_tts import ParlerTTSForConditionalGeneration
 from transformers.generation.streamers import BaseStreamer
 
-# def get_perf_counter(device):
-#     if device == "cpu":
-#         return perf_counter()
-
-#     elif "cuda" in device:
-
-
-#     else:
-#         raise NotImplementedError(f"{device} not handled")
+def next_power_of_2(x):
+    return 1 if x == 0 else 2**(x - 1).bit_length()
 
 
 def int2float(sound):
-- 
GitLab
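
Note on the padding strategy (editor's illustration, not part of the patch):
torch.compile specializes compiled graphs on tensor shapes, so an unconstrained
prompt length would trigger a fresh compilation (and, under 'reduce-overhead'
or 'max-autotune', a fresh CUDA graph capture) for nearly every sentence.
Padding each prompt to the next power of two keeps the set of shapes small and
finite, and warmup() pre-compiles exactly those buckets (16, 32, 64, 128 and
256 tokens). A minimal sketch of the bucketing, reusing the next_power_of_2
helper added to utils.py:

    def next_power_of_2(x):
        # smallest power of two >= x (returns 1 for x == 0)
        return 1 if x == 0 else 2**(x - 1).bit_length()

    # example prompt lengths -> padded bucket sizes
    for nb_tokens in (12, 19, 42, 100, 250):
        print(nb_tokens, "->", next_power_of_2(nb_tokens))
    # 12 -> 16, 19 -> 32, 42 -> 64, 100 -> 128, 250 -> 256

With five buckets, the requirement that torch._dynamo.config.cache_size_limit
be >= 2 * the number of padded prompt sizes (2 * 5 = 10 <= 15) leaves each
bucket room for more than one specialized graph.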
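A second note on the warmup (editor's reading, not asserted by the patch): the
fixed min_new_tokens = max_new_tokens = 86*3 plausibly corresponds to about
three seconds of audio at the audio encoder's frame rate, the same frame_rate
used to compute play_steps. Since, per the patch's own comment, CUDA graph
capture is re-triggered whenever a generation produces more new tokens than any
previous one, warming up above the longest generation expected at runtime lets
later calls simply replay the already-captured graphs.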