From 9ed2fcf8abbda279c3600bb2deec21f07ed03065 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 14:54:04 +0000
Subject: [PATCH] compile parler tts

---
 s2s_pipeline.py | 110 ++++++++++++++++++++++++++++++++++++++++--------
 utils.py        |  11 +----
 2 files changed, 94 insertions(+), 27 deletions(-)

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 76947e4..274c472 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -33,6 +33,7 @@ from parler_tts import (
 from utils import (
     VADIterator, 
     int2float,
+    next_power_of_2,
 )
 
 
@@ -41,7 +42,7 @@ from utils import (
 CURRENT_DIR = Path(__file__).resolve().parent
 os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") 
 torch._inductor.config.fx_graph_cache = True
-# mind about this parameter ! should be >= 2 * number of compiled models
+# note: this value should be >= 2 * the number of padded prompt sizes used for TTS compilation
 torch._dynamo.config.cache_size_limit = 15
 
 console = Console()
@@ -418,7 +419,7 @@ class WhisperSTTHandler(BaseHandler):
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         torch.cuda.synchronize()
-        if self.compile_mode:
+        if self.compile_mode not in (None, "default"):
             # generating more tokens than previously will trigger CUDA graphs capture
             # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
             warmup_gen_kwargs = {
@@ -587,6 +588,12 @@ class ParlerTTSHandlerArguments:
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         }
     )
+    tts_compile_mode: str = field(
+        default=None,
+        metadata={
+            "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
+        }
+    )
     gen_kwargs: dict = field(
         default_factory=dict,
         metadata={
@@ -618,11 +625,12 @@ class ParlerTTSHandler(BaseHandler):
             device="cuda", 
             torch_dtype="float16",
             gen_kwargs={},
+            compile_mode=None,
             description=(
                 "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
                 "She speaks very fast."
             ),
-            play_steps_s=0.5
+            play_steps_s=1
         ):
         torch_dtype = getattr(torch, torch_dtype)
         self._should_listen = should_listen
@@ -635,33 +643,99 @@ class ParlerTTSHandler(BaseHandler):
         self.device = device
         self.torch_dtype = torch_dtype
 
-        tokenized_description = self.description_tokenizer(description, return_tensors="pt")
-        input_ids = tokenized_description.input_ids.to(self.device)
-        attention_mask = tokenized_description.attention_mask.to(self.device)
-
-        self.gen_kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            **gen_kwargs
-        }
+        self.description = description
+        self.gen_kwargs = gen_kwargs
         
         framerate = self.model.audio_encoder.config.frame_rate
         self.play_steps = int(framerate * play_steps_s)
+
+        self.compile_mode = compile_mode
+        if self.compile_mode:
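+            # a static KV cache keeps tensor shapes fixed across decoding steps, so compiled graphs (and CUDA graphs) can be reused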
+            self.model.generation_config.cache_implementation = "static"
+            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+        self.warmup()
+
+    def prepare_model_inputs(
+        self,
+        prompt,
+        max_length_prompt=50,
+        pad=False,
+    ):
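+        # when pad is True, the prompt is padded to a fixed max_length so the compiled model only ever sees a small set of input shapes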
+        pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
 
-    def process(self, lm_sentence):
-        console.print(f"[green]ASSISTANT: {lm_sentence}")
-        tokenized_prompt = self.prompt_tokenizer(lm_sentence, return_tensors="pt")
+        tokenized_description = self.description_tokenizer(self.description, return_tensors="pt")
+        input_ids = tokenized_description.input_ids.to(self.device)
+        attention_mask = tokenized_description.attention_mask.to(self.device)
+
+        tokenized_prompt = self.prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
         prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
         prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
 
-        streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps)
-        tts_gen_kwargs = {
+        gen_kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
             "prompt_input_ids": prompt_input_ids,
             "prompt_attention_mask": prompt_attention_mask,
-            "streamer": streamer,
             **self.gen_kwargs
         }
+        return gen_kwargs
+    
+    def warmup(self):
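+        # warm up every padded prompt length (16, 32, 64, 128, 256 tokens) so each shape bucket is compiled before the first real request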
+        pad_lengths = [2**i for i in range(4, 9)]
+        for pad_length in pad_lengths[::-1]:
+            model_kwargs = self.prepare_model_inputs(
+                "dummy prompt", 
+                max_length_prompt=pad_length,
+                pad=True
+            )
+            # 2 warmup steps for modes that capture CUDA graphs
+            n_steps = 1 if self.compile_mode == "default" else 2
+
+            if self.compile_mode not in (None, "default"):
+                # generating more new tokens than were seen during warmup would trigger a fresh CUDA graph capture,
+                # so warm up with at least as many generated tokens as the maximum expected at inference time
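+                # ~86 audio tokens per second is the codec frame rate here, so this warms up roughly 3 s of generated audio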
+                model_kwargs = {
+                    "min_new_tokens": 86*3,
+                    "max_new_tokens": 86*3,
+                    **model_kwargs
+                }
+
+            logger.info(f"Warming up length {pad_length} tokens...")
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
+            for _ in range(n_steps):
+                _ = self.model.generate(**model_kwargs)
+            end_event.record()
+            torch.cuda.synchronize()
+            logger.info(f"Warmed up! Compilation time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+
+    def process(self, llm_sentence):
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
+
+        pad_args = {}
+        if self.compile_mode:
+            # pad to closest upper power of two
+            pad_length = next_power_of_2(nb_tokens)
+            logger.debug(f"padding to {pad_length}")
+            pad_args["pad"] = True
+            pad_args["max_length_prompt"] = pad_length
+    
+        tts_gen_kwargs = self.prepare_model_inputs(
+            llm_sentence,
+            **pad_args,
+        )
 
+        streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps)
+        tts_gen_kwargs = {
+            "streamer": streamer,
+            **tts_gen_kwargs
+        }
         torch.manual_seed(0)
         thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
         thread.start()
diff --git a/utils.py b/utils.py
index 4a3c621..6ad672c 100644
--- a/utils.py
+++ b/utils.py
@@ -12,16 +12,9 @@ from time import perf_counter
 from parler_tts import ParlerTTSForConditionalGeneration
 from transformers.generation.streamers import BaseStreamer
 
-# def get_perf_counter(device):
-#     if device == "cpu":
-#         return perf_counter()
-    
-#     elif "cuda" in device:
-
-
-#     else:
-#         raise NotImplementedError(f"{device} not handled")
 
+def next_power_of_2(x):  
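+    # smallest power of two >= x, e.g. 50 -> 64 and 64 -> 64; used to bucket prompt lengths for compilation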
+    return 1 if x == 0 else 2**(x - 1).bit_length()
 
 
 def int2float(sound):
-- 
GitLab