diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index dbb6090c2ddeacdd9d9e976de95362655acd8700..4ee0bf288f50376e2900e7d705088f4c1aad09d6 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -6,6 +6,7 @@ from queue import Queue
 from time import perf_counter
 import sys
 import os
+from pathlib import Path
 from dataclasses import dataclass, field
 from copy import copy
 import multiprocessing
@@ -35,8 +36,15 @@ from utils import (
 )
 
 
-console = Console()
+# caching allows ~50% compilation time reduction
+# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
+CURRENT_DIR = Path(__file__).resolve().parent
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
+torch._inductor.config.fx_graph_cache = True
+# NB: this limit should be >= 2 * the number of compiled models
+torch._dynamo.config.cache_size_limit = 15
+console = Console()
 
 
 @dataclass
 class ModuleArguments:
@@ -241,7 +249,7 @@ class VADHandlerArguments:
         }
     )
     min_silence_ms: int = field(
-        default=1000,
+        default=250,
         metadata={
-            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms."
+            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
         }
@@ -328,6 +336,12 @@ class WhisperSTTHandlerArguments:
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
         }
     )
+    stt_compile_mode: str = field(
+        default=None,
+        metadata={
+            "help": "Compile mode for torch.compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)."
+        }
+    )
     stt_gen_max_new_tokens: int = field(
         default=128,
         metadata={
@@ -366,8 +380,10 @@ class WhisperSTTHandler(BaseHandler):
         model_name="distil-whisper/distil-large-v3",
         device="cuda",
         torch_dtype="float16",
+        compile_mode=None,
         gen_kwargs={}
     ):
+        self.compile_mode = compile_mode
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
@@ -377,6 +393,38 @@ class WhisperSTTHandler(BaseHandler):
         ).to(device)
         self.gen_kwargs = gen_kwargs
 
+        # torch.compile the forward pass; a static KV cache is required for fullgraph=True
+        if self.compile_mode:
+            self.model.generation_config.cache_implementation = "static"
+            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+        self.warmup()
+
+    def prepare_model_inputs(self, spoken_prompt):
+        input_features = self.processor(
+            spoken_prompt, sampling_rate=16000, return_tensors="pt"
+        ).input_features
+        input_features = input_features.to(self.device, dtype=self.torch_dtype)
+        return input_features
+
+    def warmup(self):
+        # 2 warmup steps when not compiling, or when the compile mode captures CUDA graphs ("reduce-overhead"/"max-autotune"); "default" needs only 1
+        n_steps = 1 if self.compile_mode == "default" else 2
+        logger.debug(f"Warming up {self.__class__.__name__}")
+        dummy_input = torch.randn(
+            (1, self.model.config.num_mel_bins, 3000),
+            dtype=self.torch_dtype,
+            device=self.device
+        )
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start_event.record()
+        for _ in range(n_steps):
+            _ = self.model.generate(dummy_input, **self.gen_kwargs)
+        end_event.record()
+        torch.cuda.synchronize()
+        logger.debug(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")  # elapsed_time() is in ms
+
     def process(self, spoken_prompt):
         global pipeline_start
         pipeline_start = perf_counter()
@@ -542,7 +590,7 @@ class ParlerTTSHandlerArguments:
         }
     )
     play_steps_s: float = field(
-        default=0.5,
+        default=0.2,
         metadata={
-            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
+            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.2 seconds."
         }
@@ -670,6 +718,10 @@ def main():
     )
     logger = logging.getLogger(__name__)
 
+    # torch.compile debug logs
+    if module_kwargs.log_level == "debug":
+        torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
+
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(language_model_handler_kwargs, "llm")
     prepare_args(parler_tts_handler_kwargs, "tts")
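
Usage sketch (an assumption, not part of the patch): if the dataclass fields are exposed as CLI flags by the file's argument parser, the new field should surface as --stt_compile_mode. A hypothetical invocation that compiles Whisper with CUDA graph capture and enables the torch.compile debug logs wired up in main():

    python s2s_pipeline.py --stt_compile_mode reduce-overhead --log_level debug

The first run pays the full compilation cost; because TORCHINDUCTOR_CACHE_DIR is pointed at the tmp/ directory next to the script, later runs reuse the cached graphs (the ~50% reduction mentioned in the comment at the top of the diff).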