diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf6c1af2a5d2b8936c2574506ed33cdeee6dc4b
--- /dev/null
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -0,0 +1,56 @@
+import logging
+from time import perf_counter
+from baseHandler import BaseHandler
+from lightning_whisper_mlx import LightningWhisperMLX
+import numpy as np
+from rich.console import Console
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class LightningWhisperSTTHandler(BaseHandler):
+    """
+    Handles Speech-To-Text transcription using a Lightning Whisper MLX model.
+    """
+
+    def setup(
+        self,
+        model_name="distil-whisper/distil-large-v3",
+        device="cuda",
+        torch_dtype="float16",
+        compile_mode=None,
+        gen_kwargs={},
+    ):
+        self.device = device
+        self.model = LightningWhisperMLX(
+            model="distil-medium.en", batch_size=12, quant=None
+        )
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        # a single warmup step is enough here since there is no torch compilation
+        n_steps = 1
+        dummy_input = np.array([0] * 512)
+
+        for _ in range(n_steps):
+            _ = self.model.transcribe(dummy_input)["text"].strip()
+
+    def process(self, spoken_prompt):
+        logger.debug("infering whisper...")
+
+        global pipeline_start
+        pipeline_start = perf_counter()
+
+        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+
+        logger.debug("finished whisper inference")
+        console.print(f"[yellow]USER: {pred_text}")
+
+        yield pred_text
diff --git a/TTS/melotts.py b/TTS/melotts.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef660105611c99581bbb07e32184068f38cb54e
--- /dev/null
+++ b/TTS/melotts.py
@@ -0,0 +1,53 @@
+from MeloTTS.melo.api import TTS
+import logging
+from baseHandler import BaseHandler
+import librosa
+import numpy as np
+from rich.console import Console
+import torch
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class MeloTTSHandler(BaseHandler):
+    def setup(
+        self,
+        should_listen,
+        device="mps",
+        language="EN_NEWEST",
+        blocksize=512,
+    ):
+        self.should_listen = should_listen
+        self.device = device
+        self.model = TTS(language=language, device=device)
+        self.speaker_id = self.model.hps.data.spk2id["EN-Newest"]
+        self.blocksize = blocksize
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+        _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
+
+    def process(self, llm_sentence):
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        if self.device == "mps":
+            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
+
+        audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
+        if len(audio_chunk) == 0:
+            self.should_listen.set()
+            return
+        audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
+        audio_chunk = (audio_chunk * 32767).astype(np.int16)
+        for i in range(0, len(audio_chunk), self.blocksize):
+            yield np.pad(
+                audio_chunk[i : i + self.blocksize],
+                (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
+            )
+
+        self.should_listen.set()
diff --git a/baseHandler.py b/baseHandler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f5efa80adb473cff0b50336ddbc744fe50ba27c
--- /dev/null
+++ b/baseHandler.py
@@ -0,0 +1,51 @@
+from time import perf_counter
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseHandler:
+    """
+    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
+    The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
+    To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
+    Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
+    When the handler stops, `cleanup` is called and b"END" is placed in the output queue to propagate the stop signal downstream.
+    """
+
+    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
+        self.stop_event = stop_event
+        self.queue_in = queue_in
+        self.queue_out = queue_out
+        self.setup(*setup_args, **setup_kwargs)
+        self._times = []
+
+    def setup(self):
+        pass
+
+    def process(self):
+        raise NotImplementedError
+
+    def run(self):
+        while not self.stop_event.is_set():
+            input = self.queue_in.get()
+            if isinstance(input, bytes) and input == b"END":
+                # sentinel signal to avoid queue deadlock
+                logger.debug("Stopping thread")
+                break
+            start_time = perf_counter()
+            for output in self.process(input):
+                self._times.append(perf_counter() - start_time)
+                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
+                self.queue_out.put(output)
+                start_time = perf_counter()
+
+        self.cleanup()
+        self.queue_out.put(b"END")
+
+    @property
+    def last_time(self):
+        return self._times[-1]
+
+    def cleanup(self):
+        pass
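A minimal sketch of how a concrete handler plugs into the queue protocol described above, assuming only what baseHandler.py defines (stop_event, queue_in/queue_out, `process` as a generator, and the b"END" sentinel); the UpperCaseHandler name and the toy strings are hypothetical:

# Hypothetical example built on the BaseHandler introduced above.
from queue import Queue
from threading import Event, Thread

from baseHandler import BaseHandler


class UpperCaseHandler(BaseHandler):
    # `process` is a generator: everything it yields goes to queue_out.
    def process(self, text):
        yield text.upper()


stop_event = Event()
queue_in, queue_out = Queue(), Queue()
handler = UpperCaseHandler(stop_event, queue_in, queue_out)

thread = Thread(target=handler.run)
thread.start()

queue_in.put("hello")    # processed by `process`
queue_in.put(b"END")     # sentinel: stops the handler without deadlocking
print(queue_out.get())   # "HELLO"
print(queue_out.get())   # b"END", propagated to the next handler downstream
thread.join()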
diff --git a/listen_and_play.py b/listen_and_play.py
index 675f78b84a39c41f41f05b51b141fde87f105dfc..8fc0dfec09eeab4bbb465078031bd4ec9a9a82d5 100644
--- a/listen_and_play.py
+++ b/listen_and_play.py
@@ -1,5 +1,4 @@
 import socket
-import numpy as np
 import threading
 from queue import Queue
 from dataclasses import dataclass, field
@@ -9,41 +8,25 @@ from transformers import HfArgumentParser
 
 @dataclass
 class ListenAndPlayArguments:
-    send_rate: int = field(
-        default=16000,
-        metadata={
-            "help": "In Hz. Default is 16000."
-        }
-    )
-    recv_rate: int = field(
-        default=44100,
-        metadata={
-            "help": "In Hz. Default is 44100."
-        }
-    )
+    send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."})
+    recv_rate: int = field(default=44100, metadata={"help": "In Hz. Default is 44100."})
     list_play_chunk_size: int = field(
         default=1024,
-        metadata={
-            "help": "The size of data chunks (in bytes). Default is 1024."
-        }
+        metadata={"help": "The size of data chunks (in bytes). Default is 1024."},
     )
     host: str = field(
         default="localhost",
         metadata={
             "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
-        }
+        },
     )
     send_port: int = field(
         default=12345,
-        metadata={
-            "help": "The network port for sending data. Default is 12345."
-        }
+        metadata={"help": "The network port for sending data. Default is 12345."},
     )
     recv_port: int = field(
         default=12346,
-        metadata={
-            "help": "The network port for receiving data. Default is 12346."
-        }
+        metadata={"help": "The network port for receiving data. Default is 12346."},
     )
 
 
@@ -55,7 +38,6 @@ def listen_and_play(
     send_port=12345,
     recv_port=12346,
 ):
-  
     send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     send_socket.connect((host, send_port))
 
@@ -68,13 +50,13 @@ def listen_and_play(
     recv_queue = Queue()
     send_queue = Queue()
 
-    def callback_recv(outdata, frames, time, status): 
+    def callback_recv(outdata, frames, time, status):
         if not recv_queue.empty():
             data = recv_queue.get()
-            outdata[:len(data)] = data 
-            outdata[len(data):] = b'\x00' * (len(outdata) - len(data)) 
+            outdata[: len(data)] = data
+            outdata[len(data) :] = b"\x00" * (len(outdata) - len(data))
         else:
-            outdata[:] = b'\x00' * len(outdata) 
+            outdata[:] = b"\x00" * len(outdata)
 
     def callback_send(indata, frames, time, status):
         if recv_queue.empty():
@@ -85,11 +67,10 @@ def listen_and_play(
         while not stop_event.is_set():
             data = send_queue.get()
             send_socket.sendall(data)
-        
-    def recv(stop_event, recv_queue):
 
+    def recv(stop_event, recv_queue):
         def receive_full_chunk(conn, chunk_size):
-            data = b''
+            data = b""
             while len(data) < chunk_size:
                 packet = conn.recv(chunk_size - len(data))
                 if not packet:
@@ -98,13 +79,25 @@ def listen_and_play(
             return data
 
         while not stop_event.is_set():
-            data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) 
+            data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
             if data:
                 recv_queue.put(data)
 
-    try: 
-        send_stream = sd.RawInputStream(samplerate=send_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_send)
-        recv_stream = sd.RawOutputStream(samplerate=recv_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_recv)
+    try:
+        send_stream = sd.RawInputStream(
+            samplerate=send_rate,
+            channels=1,
+            dtype="int16",
+            blocksize=list_play_chunk_size,
+            callback=callback_send,
+        )
+        recv_stream = sd.RawOutputStream(
+            samplerate=recv_rate,
+            channels=1,
+            dtype="int16",
+            blocksize=list_play_chunk_size,
+            callback=callback_recv,
+        )
         threading.Thread(target=send_stream.start).start()
         threading.Thread(target=recv_stream.start).start()
 
@@ -112,7 +105,7 @@ def listen_and_play(
         send_thread.start()
         recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue))
         recv_thread.start()
-        
+
         input("Press Enter to stop...")
 
     except KeyboardInterrupt:
@@ -129,6 +122,5 @@ def listen_and_play(
 
 if __name__ == "__main__":
     parser = HfArgumentParser((ListenAndPlayArguments,))
-    listen_and_play_kwargs, = parser.parse_args_into_dataclasses()
+    (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
     listen_and_play(**vars(listen_and_play_kwargs))
-
diff --git a/local_audio_streamer.py b/local_audio_streamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5203785aa6094e692c511540a679d239e03bae33
--- /dev/null
+++ b/local_audio_streamer.py
@@ -0,0 +1,38 @@
+import threading
+import sounddevice as sd
+import numpy as np
+
+import time
+
+
+class LocalAudioStreamer:
+    def __init__(
+        self,
+        input_queue,
+        output_queue,
+        list_play_chunk_size=512,
+    ):
+        self.list_play_chunk_size = list_play_chunk_size
+
+        self.stop_event = threading.Event()
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+
+    def run(self):
+        def callback(indata, outdata, frames, time, status):
+            if self.output_queue.empty():
+                self.input_queue.put(indata.copy())
+                outdata[:] = 0 * outdata
+            else:
+                outdata[:] = self.output_queue.get()[:, np.newaxis]
+
+        with sd.Stream(
+            samplerate=16000,
+            dtype="int16",
+            channels=1,
+            callback=callback,
+            blocksize=self.list_play_chunk_size,
+        ):
+            while not self.stop_event.is_set():
+                time.sleep(0.001)
+            print("Stopping recording")
diff --git a/requirements.txt b/requirements.txt
index e04ebea181f5323d7ce91c7a916c1c34c6d45ef8..b928d18a342987752c2c2f1851939562abdeadf3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ nltk==3.8.1
 parler_tts @ git+https://github.com/huggingface/parler-tts.git
 torch==2.4.0
 sounddevice==0.5.0
+lightning-whisper-mlx==0.0.10
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index b367bb1ca929fae4339f00dca971c4a56ef19835..8c36a5d3f47c351a47c98283783c18c9f110ca16 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -3,17 +3,19 @@ import os
 import socket
 import sys
 import threading
-from collections import deque
 from copy import copy
 from dataclasses import dataclass, field
 from pathlib import Path
 from queue import Queue
 from threading import Event, Thread
 from time import perf_counter
+from typing import Optional
 
+from TTS.melotts import MeloTTSHandler
+from baseHandler import BaseHandler
+from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
 import numpy as np
 import torch
-import nltk
 from nltk.tokenize import sent_tokenize
 from rich.console import Console
 from transformers import (
@@ -23,19 +25,13 @@ from transformers import (
     AutoTokenizer,
     HfArgumentParser,
     pipeline,
-    TextIteratorStreamer
+    TextIteratorStreamer,
 )
+from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
+import librosa
 
-from parler_tts import (
-    ParlerTTSForConditionalGeneration,
-    ParlerTTSStreamer
-)
-
-from utils import (
-    VADIterator,
-    int2float,
-    next_power_of_2
-)
+from local_audio_streamer import LocalAudioStreamer
+from utils import VADIterator, int2float, next_power_of_2
 
 # Ensure that the necessary NLTK resources are available
 try:
@@ -46,10 +42,10 @@ except (LookupError, OSError):
 # caching allows ~50% compilation time reduction
 # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
 CURRENT_DIR = Path(__file__).resolve().parent
-os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") 
-torch._inductor.config.fx_graph_cache = True
-# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
-torch._dynamo.config.cache_size_limit = 15
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
+# torch._inductor.config.fx_graph_cache = True
+# # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
+# torch._dynamo.config.cache_size_limit = 15
 
 
 console = Console()
@@ -57,11 +53,21 @@ console = Console()
 
 @dataclass
 class ModuleArguments:
+    device: Optional[str] = field(
+        default=None,
+        metadata={"help": "If specified, overrides the device for all handlers."},
+    )
+    mode: Optional[str] = field(
+        default="local",
+        metadata={
+            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
+        },
+    )
     log_level: str = field(
         default="info",
         metadata={
             "help": "Provide logging level. Example --log_level debug, default=warning."
-        }
+        },
     )
 
 
@@ -87,73 +93,26 @@ class ThreadManager:
             thread.join()
 
 
-class BaseHandler:
-    """
-    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
-    The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
-    To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
-    Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
-    The cleanup method handles stopping the handler, and b"END" is placed in the output queue.
-    """
-
-    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
-        self.stop_event = stop_event
-        self.queue_in = queue_in
-        self.queue_out = queue_out
-        self.setup(*setup_args, **setup_kwargs)
-        self._times = []
-
-    def setup(self):
-        pass
-
-    def process(self):
-        raise NotImplementedError
-
-    def run(self):
-        while not self.stop_event.is_set():
-            input = self.queue_in.get()
-            if isinstance(input, bytes) and input == b'END':
-                # sentinelle signal to avoid queue deadlock
-                logger.debug("Stopping thread")
-                break
-            start_time = perf_counter()
-            for output in self.process(input):
-                self._times.append(perf_counter() - start_time)
-                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
-                self.queue_out.put(output)
-                start_time = perf_counter()
-
-        self.cleanup()
-        self.queue_out.put(b'END')
-
-    @property
-    def last_time(self):
-        return self._times[-1]
-
-    def cleanup(self):
-        pass
-
-
 @dataclass
 class SocketReceiverArguments:
     recv_host: str = field(
         default="localhost",
         metadata={
             "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
-                    "available interfaces on the host machine."
-        }
+            "available interfaces on the host machine."
+        },
     )
     recv_port: int = field(
         default=12345,
         metadata={
             "help": "The port number on which the socket server listens. Default is 12346."
-        }
+        },
     )
     chunk_size: int = field(
         default=1024,
         metadata={
             "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
-        }
+        },
     )
 
 
@@ -163,28 +122,28 @@ class SocketReceiver:
     """
 
     def __init__(
-        self, 
+        self,
         stop_event,
         queue_out,
         should_listen,
-        host='0.0.0.0', 
+        host="0.0.0.0",
         port=12345,
-        chunk_size=1024
-    ):  
+        chunk_size=1024,
+    ):
         self.stop_event = stop_event
         self.queue_out = queue_out
         self.should_listen = should_listen
-        self.chunk_size=chunk_size
+        self.chunk_size = chunk_size
         self.host = host
         self.port = port
 
     def receive_full_chunk(self, conn, chunk_size):
-        data = b''
+        data = b""
         while len(data) < chunk_size:
             packet = conn.recv(chunk_size - len(data))
             if not packet:
                 # connection closed
-                return None  
+                return None
             data += packet
         return data
 
@@ -193,7 +152,7 @@ class SocketReceiver:
         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         self.socket.bind((self.host, self.port))
         self.socket.listen(1)
-        logger.info('Receiver waiting to be connected...')
+        logger.info("Receiver waiting to be connected...")
         self.conn, _ = self.socket.accept()
         logger.info("receiver connected")
 
@@ -202,7 +161,7 @@ class SocketReceiver:
             audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
             if audio_chunk is None:
                 # connection closed
-                self.queue_out.put(b'END')
+                self.queue_out.put(b"END")
                 break
             if self.should_listen.is_set():
                 self.queue_out.put(audio_chunk)
@@ -216,48 +175,41 @@ class SocketSenderArguments:
         default="localhost",
         metadata={
             "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
-                    "available interfaces on the host machine."
-        }
+            "available interfaces on the host machine."
+        },
     )
     send_port: int = field(
         default=12346,
         metadata={
             "help": "The port number on which the socket server listens. Default is 12346."
-        }
+        },
     )
 
-            
+
 class SocketSender:
     """
     Handles sending generated audio packets to the clients.
     """
 
-    def __init__(
-        self, 
-        stop_event,
-        queue_in,
-        host='0.0.0.0', 
-        port=12346
-    ):
+    def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
         self.stop_event = stop_event
         self.queue_in = queue_in
         self.host = host
         self.port = port
-        
 
     def run(self):
         self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         self.socket.bind((self.host, self.port))
         self.socket.listen(1)
-        logger.info('Sender waiting to be connected...')
+        logger.info("Sender waiting to be connected...")
         self.conn, _ = self.socket.accept()
         logger.info("sender connected")
 
         while not self.stop_event.is_set():
             audio_chunk = self.queue_in.get()
             self.conn.sendall(audio_chunk)
-            if isinstance(audio_chunk, bytes) and audio_chunk == b'END':
+            if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
                 break
         self.conn.close()
         logger.info("Sender closed")
@@ -269,37 +221,37 @@ class VADHandlerArguments:
         default=0.3,
         metadata={
             "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
-        }
+        },
     )
     sample_rate: int = field(
         default=16000,
         metadata={
             "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
-        }
+        },
     )
     min_silence_ms: int = field(
         default=250,
         metadata={
-            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms."
-        }
+            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
+        },
     )
     min_speech_ms: int = field(
-        default=750,
+        default=500,
         metadata={
             "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
-        }
+        },
     )
     max_speech_ms: float = field(
-        default=float('inf'),
+        default=float("inf"),
         metadata={
             "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
-        }
+        },
     )
     speech_pad_ms: int = field(
-        default=30,
+        default=250,
         metadata={
-            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 30 ms."
-        }
+            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
+        },
     )
 
 
@@ -310,22 +262,21 @@ class VADHandler(BaseHandler):
     """
 
     def setup(
-            self, 
-            should_listen,
-            thresh=0.3, 
-            sample_rate=16000, 
-            min_silence_ms=1000,
-            min_speech_ms=500, 
-            max_speech_ms=float('inf'),
-            speech_pad_ms=30,
-
-        ):
+        self,
+        should_listen,
+        thresh=0.3,
+        sample_rate=16000,
+        min_silence_ms=1000,
+        min_speech_ms=500,
+        max_speech_ms=float("inf"),
+        speech_pad_ms=30,
+    ):
         self.should_listen = should_listen
         self.sample_rate = sample_rate
         self.min_silence_ms = min_silence_ms
         self.min_speech_ms = min_speech_ms
         self.max_speech_ms = max_speech_ms
-        self.model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad')
+        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
         self.iterator = VADIterator(
             self.model,
             threshold=thresh,
@@ -343,7 +294,9 @@ class VADHandler(BaseHandler):
             array = torch.cat(vad_output).cpu().numpy()
             duration_ms = len(array) / self.sample_rate * 1000
             if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
-                logger.debug(f"audio input of duration: {len(array) / self.sample_rate}s, skipping")
+                logger.debug(
+                    f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
+                )
             else:
                 self.should_listen.clear()
                 logger.debug("Stop listening")
@@ -356,56 +309,56 @@ class WhisperSTTHandlerArguments:
         default="distil-whisper/distil-large-v3",
         metadata={
             "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
-        }
+        },
     )
     stt_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
-        }
+        },
     )
     stt_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
-        } 
+        },
     )
     stt_compile_mode: str = field(
         default=None,
         metadata={
             "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
-        }
+        },
     )
     stt_gen_max_new_tokens: int = field(
         default=128,
         metadata={
             "help": "The maximum number of new tokens to generate. Default is 128."
-        }
+        },
     )
     stt_gen_num_beams: int = field(
         default=1,
         metadata={
             "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
-        }
+        },
     )
     stt_gen_return_timestamps: bool = field(
         default=False,
         metadata={
             "help": "Whether to return timestamps with transcriptions. Default is False."
-        }
-    )
-    stt_gen_task: str = field(
-        default="transcribe",
-        metadata={
-            "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
-        }
-    )
-    stt_gen_language: str = field(
-        default="en",
-        metadata={
-            "help": "The language of the speech to transcribe. Default is 'en' for English."
-        }
+        },
     )
+    # stt_gen_task: str = field(
+    #     default="transcribe",
+    #     metadata={
+    #         "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
+    #     },
+    # )
+    # stt_gen_language: str = field(
+    #     default="en",
+    #     metadata={
+    #         "help": "The language of the speech to transcribe. Default is 'en' for English."
+    #     },
+    # )
 
 
 class WhisperSTTHandler(BaseHandler):
@@ -414,16 +367,16 @@ class WhisperSTTHandler(BaseHandler):
     """
 
     def setup(
-            self,
-            model_name="distil-whisper/distil-large-v3",
-            device="cuda",  
-            torch_dtype="float16",  
-            compile_mode=None,
-            gen_kwargs={}
-        ): 
+        self,
+        model_name="distil-whisper/distil-large-v3",
+        device="cuda",
+        torch_dtype="float16",
+        compile_mode=None,
+        gen_kwargs={},
+    ):
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
-        self.compile_mode=compile_mode
+        self.compile_mode = compile_mode
         self.gen_kwargs = gen_kwargs
 
         self.processor = AutoProcessor.from_pretrained(model_name)
@@ -431,13 +384,15 @@ class WhisperSTTHandler(BaseHandler):
             model_name,
             torch_dtype=self.torch_dtype,
         ).to(device)
-        
+
         # compile
         if self.compile_mode:
             self.model.generation_config.cache_implementation = "static"
-            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+            self.model.forward = torch.compile(
+                self.model.forward, mode=self.compile_mode, fullgraph=True
+            )
         self.warmup()
-    
+
     def prepare_model_inputs(self, spoken_prompt):
         input_features = self.processor(
             spoken_prompt, sampling_rate=16000, return_tensors="pt"
@@ -445,39 +400,44 @@ class WhisperSTTHandler(BaseHandler):
         input_features = input_features.to(self.device, dtype=self.torch_dtype)
 
         return input_features
-        
+
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
 
-        # 2 warmup steps for no compile or compile mode with CUDA graphs capture 
+        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
         n_steps = 1 if self.compile_mode == "default" else 2
         dummy_input = torch.randn(
-            (1,  self.model.config.num_mel_bins, 3000),
+            (1, self.model.config.num_mel_bins, 3000),
             dtype=self.torch_dtype,
-            device=self.device
-        ) 
+            device=self.device,
+        )
         if self.compile_mode not in (None, "default"):
             # generating more tokens than previously will trigger CUDA graphs capture
             # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
             warmup_gen_kwargs = {
                 "min_new_tokens": self.gen_kwargs["max_new_tokens"],
                 "max_new_tokens": self.gen_kwargs["max_new_tokens"],
-                **self.gen_kwargs
+                **self.gen_kwargs,
             }
         else:
             warmup_gen_kwargs = self.gen_kwargs
 
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
 
-        torch.cuda.synchronize()
-        start_event.record()
         for _ in range(n_steps):
             _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
-        end_event.record()
-        torch.cuda.synchronize()
 
-        logger.info(f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+
+            logger.info(
+                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
 
     def process(self, spoken_prompt):
         logger.debug("infering whisper...")
@@ -488,9 +448,7 @@ class WhisperSTTHandler(BaseHandler):
         input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
         pred_text = self.processor.batch_decode(
-            pred_ids, 
-            skip_special_tokens=True,
-            decode_with_timestamps=False
+            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
 
         logger.debug("finished whisper inference")
@@ -502,56 +460,64 @@ class WhisperSTTHandler(BaseHandler):
 @dataclass
 class LanguageModelHandlerArguments:
     lm_model_name: str = field(
-        default="microsoft/Phi-3-mini-4k-instruct",
+        default="HuggingFaceTB/SmolLM-360M-Instruct",
         metadata={
             "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
-        }
+        },
     )
     lm_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
-        }
+        },
     )
     lm_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
-        }
+        },
     )
     user_role: str = field(
         default="user",
         metadata={
             "help": "Role assigned to the user in the chat context. Default is 'user'."
-        }
+        },
     )
     init_chat_role: str = field(
         default=None,
         metadata={
             "help": "Initial role for setting up the chat context. Default is 'system'."
-        }
+        },
     )
     init_chat_prompt: str = field(
         default="You are a helpful AI assistant.",
         metadata={
             "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
-        }
+        },
     )
     lm_gen_max_new_tokens: int = field(
         default=64,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
+        metadata={
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
+        },
     )
     lm_gen_temperature: float = field(
         default=0.0,
-        metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."}
+        metadata={
+            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
+        },
     )
     lm_gen_do_sample: bool = field(
         default=False,
-        metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
+        metadata={
+            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
+        },
     )
     chat_size: int = field(
         default=1,
-        metadata={"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."}
+        metadata={
+            "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
+        },
     )
 
 
@@ -570,7 +536,7 @@ class Chat:
         self.buffer.append(item)
         if len(self.buffer) == 2 * (self.size + 1):
             self.buffer.pop(0)
-            self.buffer.pop(0)   
+            self.buffer.pop(0)
 
     def init_chat(self, init_chat_message):
         self.init_chat_message = init_chat_message
@@ -584,34 +550,30 @@ class Chat:
 
 class LanguageModelHandler(BaseHandler):
     """
-    Handles the language model part. 
+    Handles the language model part.
     """
 
     def setup(
-            self,
-            model_name="microsoft/Phi-3-mini-4k-instruct",
-            device="cuda", 
-            torch_dtype="float16",
-            gen_kwargs={},
-            user_role="user",
-            chat_size=1,
-            init_chat_role=None, 
-            init_chat_prompt="You are a helpful AI assistant.",
-        ):
+        self,
+        model_name="microsoft/Phi-3-mini-4k-instruct",
+        device="cuda",
+        torch_dtype="float16",
+        gen_kwargs={},
+        user_role="user",
+        chat_size=1,
+        init_chat_role=None,
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            trust_remote_code=True
+            model_name, torch_dtype=torch_dtype, trust_remote_code=True
         ).to(device)
-        self.pipe = pipeline( 
-            "text-generation", 
-            model=self.model, 
-            tokenizer=self.tokenizer, 
-        ) 
+        self.pipe = pipeline(
+            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
+        )
         self.streamer = TextIteratorStreamer(
             self.tokenizer,
             skip_prompt=True,
@@ -620,16 +582,16 @@ class LanguageModelHandler(BaseHandler):
         self.gen_kwargs = {
             "streamer": self.streamer,
             "return_full_text": False,
-            **gen_kwargs
+            **gen_kwargs,
         }
 
         self.chat = Chat(chat_size)
         if init_chat_role:
             if not init_chat_prompt:
-                raise ValueError(f"An initial promt needs to be specified when setting init_chat_role.")
-            self.chat.init_chat(
-                {"role": init_chat_role, "content": init_chat_prompt}
-            )
+                raise ValueError(
+                    "An initial promt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
         self.user_role = user_role
 
         self.warmup()
@@ -642,47 +604,57 @@ class LanguageModelHandler(BaseHandler):
         warmup_gen_kwargs = {
             "min_new_tokens": self.gen_kwargs["max_new_tokens"],
             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
-            **self.gen_kwargs
+            **self.gen_kwargs,
         }
 
         n_steps = 2
 
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
 
-        torch.cuda.synchronize()
-        start_event.record()
         for _ in range(n_steps):
-            thread = Thread(target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs)
+            thread = Thread(
+                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
+            )
             thread.start()
-            for _ in self.streamer: 
-                pass    
-        end_event.record()
-        torch.cuda.synchronize()
+            for _ in self.streamer:
+                pass
 
-        logger.info(f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+
+            logger.info(
+                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
 
     def process(self, prompt):
         logger.debug("infering language model...")
 
-        self.chat.append(
-            {"role": self.user_role, "content": prompt}
+        self.chat.append({"role": self.user_role, "content": prompt})
+        thread = Thread(
+            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
         )
-        thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs)
         thread.start()
+        if self.device == "mps":
+            generated_text = ""
+            for new_text in self.streamer:
+                generated_text += new_text
+            printable_text = generated_text
+        else:
+            generated_text, printable_text = "", ""
+            for new_text in self.streamer:
+                generated_text += new_text
+                printable_text += new_text
+                sentences = sent_tokenize(printable_text)
+                if len(sentences) > 1:
+                    yield (sentences[0])
+                    printable_text = new_text
 
-        generated_text, printable_text = "", ""
-        for new_text in self.streamer:
-            generated_text += new_text
-            printable_text += new_text
-            sentences = sent_tokenize(printable_text)
-            if len(sentences) > 1:
-                yield(sentences[0])
-                printable_text = new_text
-
-        self.chat.append(
-            {"role": "assistant", "content": generated_text}
-        )
+        self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
         yield printable_text
@@ -694,33 +666,37 @@ class ParlerTTSHandlerArguments:
         default="ylacombe/parler-tts-mini-jenny-30H",
         metadata={
             "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
-        }
+        },
     )
     tts_device: str = field(
         default="cuda",
         metadata={
             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
-        }
+        },
     )
     tts_torch_dtype: str = field(
         default="float16",
         metadata={
             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
-        }
+        },
     )
     tts_compile_mode: str = field(
         default=None,
         metadata={
             "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
-        }
+        },
     )
     tts_gen_min_new_tokens: int = field(
-        default=None,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"}
+        default=64,
+        metadata={
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
+        },
     )
     tts_gen_max_new_tokens: int = field(
         default=512,
-        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"}
+        metadata={
+            "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
+        },
     )
     description: str = field(
         default=(
@@ -729,38 +705,39 @@ class ParlerTTSHandlerArguments:
         ),
         metadata={
             "help": "Description of the speaker's voice and speaking style to guide the TTS model."
-        }
+        },
     )
     play_steps_s: float = field(
-        default=0.2,
+        default=1.0,
         metadata={
             "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
-        }
+        },
     )
     max_prompt_pad_length: int = field(
         default=8,
         metadata={
             "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
-        }
-    ) 
+        },
+    )
 
 
 class ParlerTTSHandler(BaseHandler):
     def setup(
-            self,
-            should_listen,
-            model_name="ylacombe/parler-tts-mini-jenny-30H",
-            device="cuda", 
-            torch_dtype="float16",
-            compile_mode=None,
-            gen_kwargs={},
-            max_prompt_pad_length=8,
-            description=(
-                "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
-                "She speaks very fast."
-            ),
-            play_steps_s=1
-        ):
+        self,
+        should_listen,
+        model_name="ylacombe/parler-tts-mini-jenny-30H",
+        device="cuda",
+        torch_dtype="float16",
+        compile_mode=None,
+        gen_kwargs={},
+        max_prompt_pad_length=8,
+        description=(
+            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
+            "She speaks very fast."
+        ),
+        play_steps_s=1,
+        blocksize=512,
+    ):
         self.should_listen = should_listen
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
@@ -769,23 +746,27 @@ class ParlerTTSHandler(BaseHandler):
         self.max_prompt_pad_length = max_prompt_pad_length
         self.description = description
 
-        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) 
+        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = ParlerTTSForConditionalGeneration.from_pretrained(
-            model_name,
-            torch_dtype=self.torch_dtype
+            model_name, torch_dtype=self.torch_dtype
         ).to(device)
-        
+
         framerate = self.model.audio_encoder.config.frame_rate
         self.play_steps = int(framerate * play_steps_s)
+        self.blocksize = blocksize
 
         if self.compile_mode not in (None, "default"):
-            logger.warning("Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'")
+            logger.warning(
+                "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
+            )
             self.compile_mode = "default"
 
         if self.compile_mode:
             self.model.generation_config.cache_implementation = "static"
-            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+            self.model.forward = torch.compile(
+                self.model.forward, mode=self.compile_mode, fullgraph=True
+            )
 
         self.warmup()
 
@@ -795,13 +776,19 @@ class ParlerTTSHandler(BaseHandler):
         max_length_prompt=50,
         pad=False,
     ):
-        pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
+        pad_args_prompt = (
+            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
+        )
 
-        tokenized_description = self.description_tokenizer(self.description, return_tensors="pt")
+        tokenized_description = self.description_tokenizer(
+            self.description, return_tensors="pt"
+        )
         input_ids = tokenized_description.input_ids.to(self.device)
         attention_mask = tokenized_description.attention_mask.to(self.device)
 
-        tokenized_prompt = self.prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
+        tokenized_prompt = self.prompt_tokenizer(
+            prompt, return_tensors="pt", **pad_args_prompt
+        )
         prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
         prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
 
@@ -810,29 +797,29 @@ class ParlerTTSHandler(BaseHandler):
             "attention_mask": attention_mask,
             "prompt_input_ids": prompt_input_ids,
             "prompt_attention_mask": prompt_attention_mask,
-            **self.gen_kwargs
+            **self.gen_kwargs,
         }
 
         return gen_kwargs
-    
+
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
 
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
 
-        # 2 warmup steps for no compile or compile mode with CUDA graphs capture 
+        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
         n_steps = 1 if self.compile_mode == "default" else 2
 
-        torch.cuda.synchronize()
-        start_event.record()
+        if self.device == "cuda":
+            torch.cuda.synchronize()
+            start_event.record()
         if self.compile_mode:
             pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
             for pad_length in pad_lengths[::-1]:
                 model_kwargs = self.prepare_model_inputs(
-                    "dummy prompt", 
-                    max_length_prompt=pad_length,
-                    pad=True
+                    "dummy prompt", max_length_prompt=pad_length, pad=True
                 )
                 for _ in range(n_steps):
                     _ = self.model.generate(**model_kwargs)
@@ -840,12 +827,14 @@ class ParlerTTSHandler(BaseHandler):
         else:
             model_kwargs = self.prepare_model_inputs("dummy prompt")
             for _ in range(n_steps):
-                    _ = self.model.generate(**model_kwargs)
-                
-        end_event.record() 
-        torch.cuda.synchronize()
-        logger.info(f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+                _ = self.model.generate(**model_kwargs)
 
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+            logger.info(
+                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
 
     def process(self, llm_sentence):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
@@ -858,26 +847,32 @@ class ParlerTTSHandler(BaseHandler):
             logger.debug(f"padding to {pad_length}")
             pad_args["pad"] = True
             pad_args["max_length_prompt"] = pad_length
-    
+
         tts_gen_kwargs = self.prepare_model_inputs(
             llm_sentence,
             **pad_args,
         )
 
-        streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps)
-        tts_gen_kwargs = {
-            "streamer": streamer,
-            **tts_gen_kwargs
-        }
+        streamer = ParlerTTSStreamer(
+            self.model, device=self.device, play_steps=self.play_steps
+        )
+        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
         torch.manual_seed(0)
         thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
         thread.start()
 
         for i, audio_chunk in enumerate(streamer):
             if i == 0:
-                logger.info(f"Time to first audio: {perf_counter() - pipeline_start:.3f}")
-            audio_chunk = np.int16(audio_chunk * 32767)
-            yield audio_chunk
+                logger.info(
+                    f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
+                )
+            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
+            audio_chunk = (audio_chunk * 32767).astype(np.int16)
+            for i in range(0, len(audio_chunk), self.blocksize):
+                yield np.pad(
+                    audio_chunk[i : i + self.blocksize],
+                    (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
+                )
 
         self.should_listen.set()
 
@@ -891,7 +886,7 @@ def prepare_args(args, prefix):
     for key in copy(args.__dict__):
         if key.startswith(prefix):
             value = args.__dict__.pop(key)
-            new_key = key[len(prefix) + 1:]  # Remove prefix and underscore
+            new_key = key[len(prefix) + 1 :]  # Remove prefix and underscore
             if new_key.startswith("gen_"):
                 gen_kwargs[new_key[4:]] = value  # Remove 'gen_' and add to dict
             else:
@@ -901,37 +896,39 @@ def prepare_args(args, prefix):
 
 
 def main():
-    parser = HfArgumentParser((
-        ModuleArguments,
-        SocketReceiverArguments, 
-        SocketSenderArguments,
-        VADHandlerArguments,
-        WhisperSTTHandlerArguments,
-        LanguageModelHandlerArguments,
-        ParlerTTSHandlerArguments,
-    ))
+    parser = HfArgumentParser(
+        (
+            ModuleArguments,
+            SocketReceiverArguments,
+            SocketSenderArguments,
+            VADHandlerArguments,
+            WhisperSTTHandlerArguments,
+            LanguageModelHandlerArguments,
+            ParlerTTSHandlerArguments,
+        )
+    )
 
     # 0. Parse CLI arguments
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # Parse configurations from a JSON file if specified
         (
             module_kwargs,
-            socket_receiver_kwargs, 
-            socket_sender_kwargs, 
-            vad_handler_kwargs, 
-            whisper_stt_handler_kwargs, 
-            language_model_handler_kwargs, 
+            socket_receiver_kwargs,
+            socket_sender_kwargs,
+            vad_handler_kwargs,
+            whisper_stt_handler_kwargs,
+            language_model_handler_kwargs,
             parler_tts_handler_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
         (
             module_kwargs,
-            socket_receiver_kwargs, 
-            socket_sender_kwargs, 
-            vad_handler_kwargs, 
-            whisper_stt_handler_kwargs, 
-            language_model_handler_kwargs, 
+            socket_receiver_kwargs,
+            socket_sender_kwargs,
+            vad_handler_kwargs,
+            whisper_stt_handler_kwargs,
+            language_model_handler_kwargs,
             parler_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()
 
@@ -939,7 +936,7 @@ def main():
     global logger
     logging.basicConfig(
         level=module_kwargs.log_level.upper(),
-        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
     )
     logger = logging.getLogger(__name__)
 
@@ -948,20 +945,62 @@ def main():
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
     # 2. Prepare each part's arguments
+    def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
+        if common_device:
+            for kwargs in handler_kwargs:
+                if hasattr(kwargs, "lm_device"):
+                    kwargs.lm_device = common_device
+                if hasattr(kwargs, "tts_device"):
+                    kwargs.tts_device = common_device
+                if hasattr(kwargs, "stt_device"):
+                    kwargs.stt_device = common_device
+
+    # Call this function with the common device and all the handlers
+    overwrite_device_argument(
+        module_kwargs.device,
+        language_model_handler_kwargs,
+        parler_tts_handler_kwargs,
+        whisper_stt_handler_kwargs,
+    )
+
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(language_model_handler_kwargs, "lm")
-    prepare_args(parler_tts_handler_kwargs, "tts") 
+    prepare_args(parler_tts_handler_kwargs, "tts")
 
     # 3. Build the pipeline
     stop_event = Event()
     # used to stop putting received audio chunks in queue until all setences have been processed by the TTS
-    should_listen = Event() 
+    should_listen = Event()
     recv_audio_chunks_queue = Queue()
     send_audio_chunks_queue = Queue()
-    spoken_prompt_queue = Queue() 
+    spoken_prompt_queue = Queue()
     text_prompt_queue = Queue()
     lm_response_queue = Queue()
-    
+
+    if module_kwargs.mode == "local":
+        local_audio_streamer = LocalAudioStreamer(
+            input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
+        )
+        comms_handlers = [local_audio_streamer]
+        should_listen.set()
+    else:
+        comms_handlers = [
+            SocketReceiver(
+                stop_event,
+                recv_audio_chunks_queue,
+                should_listen,
+                host=socket_receiver_kwargs.recv_host,
+                port=socket_receiver_kwargs.recv_port,
+                chunk_size=socket_receiver_kwargs.chunk_size,
+            ),
+            SocketSender(
+                stop_event,
+                send_audio_chunks_queue,
+                host=socket_sender_kwargs.send_host,
+                port=socket_sender_kwargs.send_port,
+            ),
+        ]
+
     vad = VADHandler(
         stop_event,
         queue_in=recv_audio_chunks_queue,
@@ -969,7 +1008,7 @@ def main():
         setup_args=(should_listen,),
         setup_kwargs=vars(vad_handler_kwargs),
     )
-    stt = WhisperSTTHandler(
+    stt = LightningWhisperSTTHandler(
         stop_event,
         queue_in=spoken_prompt_queue,
         queue_out=text_prompt_queue,
@@ -981,37 +1020,28 @@ def main():
         queue_out=lm_response_queue,
         setup_kwargs=vars(language_model_handler_kwargs),
     )
-    tts = ParlerTTSHandler(
+    # tts = ParlerTTSHandler(
+    #     stop_event,
+    #     queue_in=lm_response_queue,
+    #     queue_out=send_audio_chunks_queue,
+    #     setup_args=(should_listen,),
+    #     setup_kwargs=vars(parler_tts_handler_kwargs),
+    # )
+    tts = MeloTTSHandler(
         stop_event,
         queue_in=lm_response_queue,
         queue_out=send_audio_chunks_queue,
         setup_args=(should_listen,),
-        setup_kwargs=vars(parler_tts_handler_kwargs),
-    )  
-
-    recv_handler = SocketReceiver(
-        stop_event, 
-        recv_audio_chunks_queue, 
-        should_listen,
-        host=socket_receiver_kwargs.recv_host,
-        port=socket_receiver_kwargs.recv_port,
-        chunk_size=socket_receiver_kwargs.chunk_size,
     )
 
-    send_handler = SocketSender(
-        stop_event, 
-        send_audio_chunks_queue,
-        host=socket_sender_kwargs.send_host,
-        port=socket_sender_kwargs.send_port,
-        )
-
     # 4. Run the pipeline
     try:
-        pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler])
+        pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
         pipeline_manager.start()
 
     except KeyboardInterrupt:
         pipeline_manager.stop()
-    
+
+
 if __name__ == "__main__":
     main()
diff --git a/utils.py b/utils.py
index 6a353c7f136f0af6cc8e22c53b97642a0ad47e19..f4237a1121e5a398e09bb8249bd37a9bffa81b43 100644
--- a/utils.py
+++ b/utils.py
@@ -2,8 +2,8 @@ import numpy as np
 import torch
 
 
-def next_power_of_2(x):  
-    return 1 if x == 0 else 2**(x - 1).bit_length()
+def next_power_of_2(x):
+    return 1 if x == 0 else 2 ** (x - 1).bit_length()
 
 
 def int2float(sound):
@@ -12,22 +12,22 @@ def int2float(sound):
     """
 
     abs_max = np.abs(sound).max()
-    sound = sound.astype('float32')
+    sound = sound.astype("float32")
     if abs_max > 0:
-        sound *= 1/32768
+        sound *= 1 / 32768
     sound = sound.squeeze()  # depends on the use case
     return sound
 
 
 class VADIterator:
-    def __init__(self,
-                 model,
-                 threshold: float = 0.5,
-                 sampling_rate: int = 16000,
-                 min_silence_duration_ms: int = 100,
-                 speech_pad_ms: int = 30
-                 ):
-
+    def __init__(
+        self,
+        model,
+        threshold: float = 0.5,
+        sampling_rate: int = 16000,
+        min_silence_duration_ms: int = 100,
+        speech_pad_ms: int = 30,
+    ):
         """
         Mainly taken from https://github.com/snakers4/silero-vad
         Class for stream imitation
@@ -57,14 +57,15 @@ class VADIterator:
         self.buffer = []
 
         if sampling_rate not in [8000, 16000]:
-            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+            raise ValueError(
+                "VADIterator does not support sampling rates other than [8000, 16000]"
+            )
 
         self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
         self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
         self.reset_states()
 
     def reset_states(self):
-
         self.model.reset_states()
         self.triggered = False
         self.temp_end = 0
@@ -107,11 +108,11 @@ class VADIterator:
                 # end of speak
                 self.temp_end = 0
                 self.triggered = False
-                spoken_utterance = self.buffer 
+                spoken_utterance = self.buffer
                 self.buffer = []
                 return spoken_utterance
-            
+
         if self.triggered:
             self.buffer.append(x)
-                
+
         return None