From 0574f84029d4f0d6fa131bb5899100098567db33 Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan <eulebihan@gmail.com>
Date: Tue, 13 Aug 2024 16:06:55 +0000
Subject: [PATCH] Compile Parler TTS
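
Pad prompts to the nearest power of 2 and warm the model up across every
padded length so the compiled graphs can be reused, expose explicit
tts_gen_min_new_tokens/tts_gen_max_new_tokens arguments in place of the
free-form gen_kwargs dict, revert CUDA-graph-capturing compile modes to
'default' (not yet compatible with the STT part), and skip empty VAD
outputs.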

---
 s2s_pipeline.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)
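
Note (illustration, not code from this patch): the hunks below pad the TTS
prompt to a power of 2 before running the compiled model, so each prompt
length hits a graph that was already traced during warmup. A minimal sketch
of the rounding this implies; the helper name and bounds are assumptions
that mirror warmup()'s range(2, max_prompt_pad_length):

    import math

    def pad_length_for(prompt_len, max_prompt_pad_length=8):
        # Round the tokenized prompt length up to the nearest power of 2 so
        # a graph warmed up at that size can be reused, clamping the exponent
        # to the ones covered by warmup().
        exponent = math.ceil(math.log2(max(prompt_len, 1)))
        exponent = min(max(exponent, 2), max_prompt_pad_length - 1)
        return 2 ** exponent

    assert pad_length_for(19) == 32  # a 19-token prompt pads to 32 tokens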

diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 5c81b59..eebb2c6 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -301,7 +301,7 @@ class VADHandler(BaseHandler):
         audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
         audio_float32 = int2float(audio_int16)
         vad_output = self.iterator(torch.from_numpy(audio_float32))
-        if vad_output is not None:
+        if vad_output is not None and len(vad_output) != 0:
             logger.debug("VAD: end of speech detected")
             array = torch.cat(vad_output).cpu().numpy()
             duration_ms = len(array) / self._sample_rate * 1000
@@ -622,11 +622,13 @@ class ParlerTTSHandlerArguments:
             "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
         }
     )
-    gen_kwargs: dict = field(
-        default_factory=dict,
-        metadata={
-            "help": "Additional keyword arguments to pass to the model's generate method. Use this to customize generation settings."
-        }
+    tts_gen_min_new_tokens: int = field(
+        default=10,
+        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"}
+    )
+    tts_gen_max_new_tokens: int = field(
+        default=512,
+        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"}
     )
     description: str = field(
         default=(
@@ -643,6 +645,12 @@ class ParlerTTSHandlerArguments:
             "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
         }
     )
+    max_prompt_pad_length: int = field(
+        default=8,
+        metadata={
+            "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
+        }
+    ) 
 
 
 class ParlerTTSHandler(BaseHandler):
@@ -652,6 +660,7 @@ class ParlerTTSHandler(BaseHandler):
             model_name="ylacombe/parler-tts-mini-jenny-30H",
             device="cuda", 
             torch_dtype="float16",
+            max_prompt_pad_length=8,
             gen_kwargs={},
             compile_mode=None,
             description=(
@@ -660,6 +669,7 @@ class ParlerTTSHandler(BaseHandler):
             ),
             play_steps_s=1
         ):
+        self.max_prompt_pad_length = max_prompt_pad_length
         torch_dtype = getattr(torch, torch_dtype)
         self._should_listen = should_listen
         self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) 
@@ -681,6 +691,10 @@ class ParlerTTSHandler(BaseHandler):
         self.play_steps = int(framerate * play_steps_s)
 
         self.compile_mode = compile_mode
+        if self.compile_mode not in (None, "default"):
+            logger.warning("Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'")
+            self.compile_mode = "default"
+
         if self.compile_mode:
             self.model.generation_config.cache_implementation = "static"
             self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
@@ -712,7 +726,7 @@ class ParlerTTSHandler(BaseHandler):
         return gen_kwargs
     
     def warmup(self):
-        pad_lengths = [2**i for i in range(4, 9)]
+        pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
         for pad_length in pad_lengths[::-1]:
             model_kwargs = self.prepare_model_inputs(
                 "dummy prompt", 
@@ -722,15 +736,6 @@ class ParlerTTSHandler(BaseHandler):
             # 2 warmup steps for modes that capture CUDA graphs
             n_steps = 1 if self.compile_mode == "default" else 2
 
-            if self.compile_mode not in (None, "default"):
-                # generating more tokens than previously will trigger CUDA graphs capture
-                # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
-                model_kwargs = {
-                    "min_new_tokens": 86*3,
-                    "max_new_tokens": 86*3,
-                    **model_kwargs
-                }
-
             logger.info(f"Warming up length {pad_length} tokens...")
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
-- 
GitLab
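
Postscript (illustrative sketch, not part of the patch): the compile-and-warmup
pattern the patch applies to ParlerTTSHandler, reduced to a generic module. The
Linear layer and sizes are stand-ins, not the Parler TTS API.

    import torch

    model = torch.nn.Linear(512, 512).to("cuda", torch.float16)
    model.forward = torch.compile(model.forward, mode="default", fullgraph=True)

    # Run one forward pass per padded size, largest first, timing each with
    # CUDA events as warmup() does.
    for pad_length in [2**i for i in range(2, 8)][::-1]:
        dummy = torch.zeros(pad_length, 512, device="cuda", dtype=torch.float16)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        _ = model(dummy)
        end_event.record()
        torch.cuda.synchronize()  # drain queued kernels before reading timers
        print(f"pad {pad_length}: {start_event.elapsed_time(end_event):.1f} ms")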