diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index dbb6090c2ddeacdd9d9e976de95362655acd8700..4ee0bf288f50376e2900e7d705088f4c1aad09d6 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -6,6 +6,7 @@ from queue import Queue
 from time import perf_counter
 import sys
 import os
+from pathlib import Path
 from dataclasses import dataclass, field
 from copy import copy
 import multiprocessing
@@ -35,8 +36,15 @@ from utils import (
 )
 
 
-console = Console()
+# caching allows ~50% compilation time reduction
+# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
+CURRENT_DIR = Path(__file__).resolve().parent
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
+torch._inductor.config.fx_graph_cache = True
+# NB: this limit should be >= 2 * the number of compiled models
+torch._dynamo.config.cache_size_limit = 15
+console = Console()
 
 
 @dataclass
 class ModuleArguments:
@@ -241,7 +249,7 @@ class VADHandlerArguments:
         }
     )
     min_silence_ms: int = field(
-        default=1000,
+        default=250,
         metadata={
-            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms."
+            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
         }
@@ -328,6 +336,12 @@ class WhisperSTTHandlerArguments:
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         }
     )
+    stt_compile_mode: str = field(
+        default=None,
+        metadata={
+            "help": "Compile mode for torch.compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)."
+        }
+    )
     stt_gen_max_new_tokens: int = field(
         default=128,
         metadata={
@@ -366,8 +380,10 @@ class WhisperSTTHandler(BaseHandler):
         model_name="distil-whisper/distil-large-v3",
         device="cuda",
         torch_dtype="float16",
+        compile_mode=None,
         gen_kwargs={}
     ):
+        self.compile_mode = compile_mode
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
@@ -377,6 +393,38 @@ class WhisperSTTHandler(BaseHandler):
         ).to(device)
         self.gen_kwargs = gen_kwargs
 
+        # torch.compile the forward pass; a static KV cache is required for fullgraph=True
+        if self.compile_mode:
+            self.model.generation_config.cache_implementation = "static"
+            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+        self.warmup()
+
+    def prepare_model_inputs(self, spoken_prompt):
+        input_features = self.processor(
+            spoken_prompt, sampling_rate=16000, return_tensors="pt"
+        ).input_features
+        input_features = input_features.to(self.device, dtype=self.torch_dtype)
+        return input_features
+
+    def warmup(self):
+        # 2 warmup steps when not compiling, or when the compile mode captures CUDA graphs ("reduce-overhead"/"max-autotune"); "default" needs only 1
+        n_steps = 1 if self.compile_mode == "default" else 2
+        logger.debug(f"Warming up {self.__class__.__name__}")
+        dummy_input = torch.randn(
+            (1, self.model.config.num_mel_bins, 3000),
+            dtype=self.torch_dtype,
+            device=self.device
+        )
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start_event.record()
+        for _ in range(n_steps):
+            _ = self.model.generate(dummy_input, **self.gen_kwargs)
+        end_event.record()
+        torch.cuda.synchronize()
+        logger.debug(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")  # elapsed_time() is in ms
+
     def process(self, spoken_prompt):
         global pipeline_start
         pipeline_start = perf_counter()
@@ -542,7 +590,7 @@ class ParlerTTSHandlerArguments:
         }
     )
     play_steps_s: float = field(
-        default=0.5,
+        default=0.2,
         metadata={
-            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
+            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.2 seconds."
         }
@@ -670,6 +718,10 @@ def main():
     )
     logger = logging.getLogger(__name__)
 
+    # torch.compile debug logs
+    if module_kwargs.log_level == "debug":
+        torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
+
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(language_model_handler_kwargs, "llm")
     prepare_args(parler_tts_handler_kwargs, "tts")
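
Usage sketch (an assumption, not part of the patch): if the dataclass fields are exposed as CLI flags by the file's argument parser, the new field should surface as --stt_compile_mode. A hypothetical invocation that compiles Whisper with CUDA graph capture and enables the torch.compile debug logs wired up in main():

    python s2s_pipeline.py --stt_compile_mode reduce-overhead --log_level debug

The first run pays the full compilation cost; because TORCHINDUCTOR_CACHE_DIR is pointed at the tmp/ directory next to the script, later runs reuse the cached graphs (the ~50% reduction mentioned in the comment at the top of the diff).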