Commit f7dbb827 authored by Eustache Le Bihan

add whisper compile

parent 436c74ca
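In short, this change compiles the Whisper STT model's forward pass under a static KV cache, configures the Inductor/Dynamo caches, and warms the model up before serving audio. A minimal standalone sketch of that pattern, assuming a CUDA device, PyTorch 2.x and transformers; the generation settings below are illustrative and not taken verbatim from this commit:

# Sketch only: mirrors the compile-and-warmup pattern added in this commit.
import torch
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "distil-whisper/distil-large-v3", torch_dtype=torch.float16
).to("cuda")

# A static KV cache keeps tensor shapes fixed, so the compiled graph can be reused across calls.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Warm up on a dummy 30 s log-mel input (batch, num_mel_bins, 3000 frames); the extra pass
# gives CUDA graph capture a chance to finish before real audio arrives.
dummy = torch.randn((1, model.config.num_mel_bins, 3000), dtype=torch.float16, device="cuda")
for _ in range(2):
    _ = model.generate(dummy, max_new_tokens=32)

The diff below wires this pattern into WhisperSTTHandler behind a new stt_compile_mode argument and enables Inductor's FX graph cache to cut recompilation time.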
@@ -6,6 +6,7 @@ from queue import Queue
from time import perf_counter
import sys
import os
from pathlib import Path
from dataclasses import dataclass, field
from copy import copy
import multiprocessing
@@ -35,8 +36,15 @@ from utils import (
)
console = Console()
# caching allows ~50% compilation time reduction
# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
CURRENT_DIR = Path(__file__).resolve().parent
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
torch._inductor.config.fx_graph_cache = True
# NOTE: keep this limit >= 2 * the number of compiled models
torch._dynamo.config.cache_size_limit = 15
console = Console()
@dataclass
class ModuleArguments:
@@ -241,7 +249,7 @@ class VADHandlerArguments:
}
)
min_silence_ms: int = field(
- default=1000,
+ default=250,
metadata={
"help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms."
}
@@ -328,6 +336,12 @@ class WhisperSTTHandlerArguments:
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
}
)
stt_compile_mode: str = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
}
)
stt_gen_max_new_tokens: int = field(
default=128,
metadata={
@@ -357,7 +371,7 @@ class WhisperSTTHandlerArguments:
metadata={
"help": "The language of the speech to transcribe. Default is 'en' for English."
}
- )
+ )
class WhisperSTTHandler(BaseHandler):
@@ -366,8 +380,10 @@ class WhisperSTTHandler(BaseHandler):
model_name="distil-whisper/distil-large-v3",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={}
- ):
+ ):
self.compile_mode = compile_mode
self.processor = AutoProcessor.from_pretrained(model_name)
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
@@ -377,6 +393,38 @@ class WhisperSTTHandler(BaseHandler):
).to(device)
self.gen_kwargs = gen_kwargs
# compile the forward pass; a static KV cache keeps tensor shapes fixed so the compiled graph is reused
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
self.warmup()
def prepare_model_inputs(self, spoken_prompt):
input_features = self.processor(
spoken_prompt, sampling_rate=16000, return_tensors="pt"
).input_features
input_features = input_features.to(self.device, dtype=self.torch_dtype)
return input_features
def warmup(self):
# two warmup passes when not compiling or when the compile mode captures CUDA graphs
# ("reduce-overhead", "max-autotune"); a single pass is enough for "default"
n_steps = 1 if self.compile_mode == "default" else 2
logger.debug(f"Warming up {self.__class__.__name__}")
dummy_input = torch.randn(
(1, self.model.config.num_mel_bins, 3000),
dtype=self.torch_dtype,
device=self.device
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
_ = self.model.generate(dummy_input, **self.gen_kwargs)
end_event.record()
torch.cuda.synchronize()
logger.debug(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
def process(self, spoken_prompt):
global pipeline_start
pipeline_start = perf_counter()
@@ -542,7 +590,7 @@ class ParlerTTSHandlerArguments:
}
)
play_steps_s: float = field(
- default=0.5,
+ default=0.2,
metadata={
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
}
@@ -670,6 +718,10 @@ def main():
)
logger = logging.getLogger(__name__)
# torch compile logs
if module_kwargs.log_level == "debug":
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(language_model_handler_kwargs, "llm")
prepare_args(parler_tts_handler_kwargs, "tts")
......