From d50687a0c379a705a37f8fa26d65634a48ca6334 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Fri, 23 Aug 2024 16:57:38 +0200
Subject: [PATCH] refactor all the handlers - folder structure

---
 LLM/language_model.py                         | 134 ++++
 LLM/{mlx_lm.py => mlx_language_model.py}      |   8 +-
 STT/whisper_stt_handler.py                    | 113 ++++
 TTS/{melotts.py => melo_handler.py}           |   1 -
 TTS/parler_handler.py                         | 181 ++++++
 VAD/vad_handler.py                            |  64 ++
 utils.py => VAD/vad_iterator.py               |  18 -
 .../local_audio_streamer.py                   |   0
 connections/socket_receiver.py                |  63 ++
 connections/socket_sender.py                  |  39 ++
 s2s_pipeline.py                               | 587 +-----------------
 utils/thread_manager.py                       |  23 +
 utils/utils.py                                |  18 +
 13 files changed, 660 insertions(+), 589 deletions(-)
 create mode 100644 LLM/language_model.py
 rename LLM/{mlx_lm.py => mlx_language_model.py} (95%)
 create mode 100644 STT/whisper_stt_handler.py
 rename TTS/{melotts.py => melo_handler.py} (98%)
 create mode 100644 TTS/parler_handler.py
 create mode 100644 VAD/vad_handler.py
 rename utils.py => VAD/vad_iterator.py (89%)
 rename local_audio_streamer.py => connections/local_audio_streamer.py (100%)
 create mode 100644 connections/socket_receiver.py
 create mode 100644 connections/socket_sender.py
 create mode 100644 utils/thread_manager.py
 create mode 100644 utils/utils.py

diff --git a/LLM/language_model.py b/LLM/language_model.py
new file mode 100644
index 0000000..5369c73
--- /dev/null
+++ b/LLM/language_model.py
@@ -0,0 +1,134 @@
+from threading import Thread
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    pipeline,
+    TextIteratorStreamer,
+)
+import torch
+
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+from nltk import sent_tokenize
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class LanguageModelHandler(BaseHandler):
+    """
+    Handles text generation with a transformers language model, streaming the reply sentence by sentence.
+    """
+
+    def setup(
+        self,
+        model_name="microsoft/Phi-3-mini-4k-instruct",
+        device="cuda",
+        torch_dtype="float16",
+        gen_kwargs={},
+        user_role="user",
+        chat_size=1,
+        init_chat_role=None,
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.device = device
+        self.torch_dtype = getattr(torch, torch_dtype)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
+        ).to(device)
+        self.pipe = pipeline(
+            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
+        )
+        self.streamer = TextIteratorStreamer(
+            self.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+        )
+        self.gen_kwargs = {
+            "streamer": self.streamer,
+            "return_full_text": False,
+            **gen_kwargs,
+        }
+
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial promt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
+
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        dummy_input_text = "Write me a poem about Machine Learning."
+        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
+        warmup_gen_kwargs = {
+            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+            **self.gen_kwargs,
+        }
+
+        n_steps = 2
+
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
+
+        for _ in range(n_steps):
+            thread = Thread(
+                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
+            )
+            thread.start()
+            for _ in self.streamer:
+                pass
+
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+
+            logger.info(
+                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
+
+    def process(self, prompt):
+        logger.debug("infering language model...")
+
+        self.chat.append({"role": self.user_role, "content": prompt})
+        thread = Thread(
+            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
+        )
+        thread.start()
+        if self.device == "mps":
+            generated_text = ""
+            for new_text in self.streamer:
+                generated_text += new_text
+            printable_text = generated_text
+            torch.mps.empty_cache()
+        else:
+            generated_text, printable_text = "", ""
+            for new_text in self.streamer:
+                generated_text += new_text
+                printable_text += new_text
+                sentences = sent_tokenize(printable_text)
+                if len(sentences) > 1:
+                    yield (sentences[0])
+                    printable_text = new_text
+
+        self.chat.append({"role": "assistant", "content": generated_text})
+
+        # don't forget last sentence
+        yield printable_text
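
The sentence-chunking logic in process can be read in isolation. A minimal sketch, assuming NLTK's punkt data is available and using a plain iterator in place of TextIteratorStreamer (names here are illustrative):

from nltk import sent_tokenize

def stream_sentences(token_stream):
    # Accumulate streamed text and yield each sentence as soon as the next one
    # starts forming, mirroring the non-MPS branch of LanguageModelHandler.process.
    printable_text = ""
    for new_text in token_stream:
        printable_text += new_text
        sentences = sent_tokenize(printable_text)
        if len(sentences) > 1:
            yield sentences[0]
            printable_text = new_text
    yield printable_text  # don't forget the last sentence

# list(stream_sentences(iter(["Hello", " world.", " How", " are", " you?"])))
# -> ["Hello world.", " How are you?"]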
diff --git a/LLM/mlx_lm.py b/LLM/mlx_language_model.py
similarity index 95%
rename from LLM/mlx_lm.py
rename to LLM/mlx_language_model.py
index 0192afb..65c3b2d 100644
--- a/LLM/mlx_lm.py
+++ b/LLM/mlx_language_model.py
@@ -66,13 +66,15 @@ class MLXLanguageModelHandler(BaseHandler):
         logger.debug("infering language model...")
 
         self.chat.append({"role": self.user_role, "content": prompt})
-        
+
         # Remove system messages if using a Gemma model
         if "gemma" in self.model_name.lower():
-            chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"]
+            chat_messages = [
+                msg for msg in self.chat.to_list() if msg["role"] != "system"
+            ]
         else:
             chat_messages = self.chat.to_list()
-        
+
         prompt = self.tokenizer.apply_chat_template(
             chat_messages, tokenize=False, add_generation_prompt=True
         )
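
The Gemma branch above is a plain role filter; as a standalone sketch (the message list and model name here are illustrative):

def drop_system_messages(chat_messages, model_name):
    # Gemma chat templates reject the "system" role, so strip those entries
    # before calling tokenizer.apply_chat_template.
    if "gemma" in model_name.lower():
        return [msg for msg in chat_messages if msg["role"] != "system"]
    return chat_messages

chat = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Hello!"},
]
assert drop_system_messages(chat, "mlx-community/gemma-2-2b-it") == chat[1:]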
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
new file mode 100644
index 0000000..4f13a95
--- /dev/null
+++ b/STT/whisper_stt_handler.py
@@ -0,0 +1,113 @@
+from time import perf_counter
+from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+)
+import torch
+
+from baseHandler import BaseHandler
+from rich.console import Console
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class WhisperSTTHandler(BaseHandler):
+    """
+    Handles speech-to-text transcription using a Whisper model.
+    """
+
+    def setup(
+        self,
+        model_name="distil-whisper/distil-large-v3",
+        device="cuda",
+        torch_dtype="float16",
+        compile_mode=None,
+        gen_kwargs={},
+    ):
+        self.device = device
+        self.torch_dtype = getattr(torch, torch_dtype)
+        self.compile_mode = compile_mode
+        self.gen_kwargs = gen_kwargs
+
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_name,
+            torch_dtype=self.torch_dtype,
+        ).to(device)
+
+        # compile
+        if self.compile_mode:
+            self.model.generation_config.cache_implementation = "static"
+            self.model.forward = torch.compile(
+                self.model.forward, mode=self.compile_mode, fullgraph=True
+            )
+        self.warmup()
+
+    def prepare_model_inputs(self, spoken_prompt):
+        input_features = self.processor(
+            spoken_prompt, sampling_rate=16000, return_tensors="pt"
+        ).input_features
+        input_features = input_features.to(self.device, dtype=self.torch_dtype)
+
+        return input_features
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        # 2 warmup steps when not compiling or when the compile mode captures CUDA graphs; 1 for "default"
+        n_steps = 1 if self.compile_mode == "default" else 2
+        dummy_input = torch.randn(
+            (1, self.model.config.num_mel_bins, 3000),
+            dtype=self.torch_dtype,
+            device=self.device,
+        )
+        if self.compile_mode not in (None, "default"):
+            # generating more tokens than in a previous call triggers a new CUDA graph capture;
+            # warm up with at least as many generated tokens as any subsequent generation will request
+            warmup_gen_kwargs = {
+                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+                **self.gen_kwargs,
+            }
+        else:
+            warmup_gen_kwargs = self.gen_kwargs
+
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            torch.cuda.synchronize()
+            start_event.record()
+
+        for _ in range(n_steps):
+            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
+
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+
+            logger.info(
+                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
+
+    def process(self, spoken_prompt):
+        logger.debug("infering whisper...")
+
+        global pipeline_start
+        pipeline_start = perf_counter()
+
+        input_features = self.prepare_model_inputs(spoken_prompt)
+        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
+        pred_text = self.processor.batch_decode(
+            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
+        )[0]
+
+        logger.debug("finished whisper inference")
+        console.print(f"[yellow]USER: {pred_text}")
+
+        yield pred_text
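
For reference, the same prepare/generate/decode path outside the handler, run on one second of silence (a sketch: the model, device, and max_new_tokens choices are illustrative, and the queue plumbing and warmup are omitted):

import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_name = "distil-whisper/distil-large-v3"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, torch_dtype=torch.float16
).to("cuda")

spoken_prompt = np.zeros(16000, dtype=np.float32)  # 1 s of 16 kHz audio, as the VAD stage emits
input_features = processor(
    spoken_prompt, sampling_rate=16000, return_tensors="pt"
).input_features.to("cuda", dtype=torch.float16)
pred_ids = model.generate(input_features, max_new_tokens=64)
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])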
diff --git a/TTS/melotts.py b/TTS/melo_handler.py
similarity index 98%
rename from TTS/melotts.py
rename to TTS/melo_handler.py
index fad87d9..9897679 100644
--- a/TTS/melotts.py
+++ b/TTS/melo_handler.py
@@ -24,7 +24,6 @@ class MeloTTSHandler(BaseHandler):
         gen_kwargs={},  # Unused
         blocksize=512,
     ):
-        print(device)
         self.should_listen = should_listen
         self.device = device
         self.model = TTS(language=language, device=device)
diff --git a/TTS/parler_handler.py b/TTS/parler_handler.py
new file mode 100644
index 0000000..efeb5a8
--- /dev/null
+++ b/TTS/parler_handler.py
@@ -0,0 +1,181 @@
+from threading import Thread
+from time import perf_counter
+from baseHandler import BaseHandler
+import numpy as np
+import torch
+from transformers import (
+    AutoTokenizer,
+)
+from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
+import librosa
+import logging
+from rich.console import Console
+from utils.utils import next_power_of_2
+
+torch._inductor.config.fx_graph_cache = True
+# mind this parameter! It should be >= 2 * the number of padded prompt sizes used for TTS
+torch._dynamo.config.cache_size_limit = 15
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class ParlerTTSHandler(BaseHandler):
+    def setup(
+        self,
+        should_listen,
+        model_name="ylacombe/parler-tts-mini-jenny-30H",
+        device="cuda",
+        torch_dtype="float16",
+        compile_mode=None,
+        gen_kwargs={},
+        max_prompt_pad_length=8,
+        description=(
+            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
+            "She speaks very fast."
+        ),
+        play_steps_s=1,
+        blocksize=512,
+    ):
+        self.should_listen = should_listen
+        self.device = device
+        self.torch_dtype = getattr(torch, torch_dtype)
+        self.gen_kwargs = gen_kwargs
+        self.compile_mode = compile_mode
+        self.max_prompt_pad_length = max_prompt_pad_length
+        self.description = description
+
+        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
+            model_name, torch_dtype=self.torch_dtype
+        ).to(device)
+
+        framerate = self.model.audio_encoder.config.frame_rate
+        self.play_steps = int(framerate * play_steps_s)
+        self.blocksize = blocksize
+
+        if self.compile_mode not in (None, "default"):
+            logger.warning(
+                "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
+            )
+            self.compile_mode = "default"
+
+        if self.compile_mode:
+            self.model.generation_config.cache_implementation = "static"
+            self.model.forward = torch.compile(
+                self.model.forward, mode=self.compile_mode, fullgraph=True
+            )
+
+        self.warmup()
+
+    def prepare_model_inputs(
+        self,
+        prompt,
+        max_length_prompt=50,
+        pad=False,
+    ):
+        pad_args_prompt = (
+            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
+        )
+
+        tokenized_description = self.description_tokenizer(
+            self.description, return_tensors="pt"
+        )
+        input_ids = tokenized_description.input_ids.to(self.device)
+        attention_mask = tokenized_description.attention_mask.to(self.device)
+
+        tokenized_prompt = self.prompt_tokenizer(
+            prompt, return_tensors="pt", **pad_args_prompt
+        )
+        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
+        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
+
+        gen_kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "prompt_input_ids": prompt_input_ids,
+            "prompt_attention_mask": prompt_attention_mask,
+            **self.gen_kwargs,
+        }
+
+        return gen_kwargs
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        if self.device == "cuda":
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+
+        # 2 warmup steps when not compiling or when the compile mode captures CUDA graphs; 1 for "default"
+        n_steps = 1 if self.compile_mode == "default" else 2
+
+        if self.device == "cuda":
+            torch.cuda.synchronize()
+            start_event.record()
+        if self.compile_mode:
+            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
+            for pad_length in pad_lengths[::-1]:
+                model_kwargs = self.prepare_model_inputs(
+                    "dummy prompt", max_length_prompt=pad_length, pad=True
+                )
+                for _ in range(n_steps):
+                    _ = self.model.generate(**model_kwargs)
+                logger.info(f"Warmed up length {pad_length} tokens!")
+        else:
+            model_kwargs = self.prepare_model_inputs("dummy prompt")
+            for _ in range(n_steps):
+                _ = self.model.generate(**model_kwargs)
+
+        if self.device == "cuda":
+            end_event.record()
+            torch.cuda.synchronize()
+            logger.info(
+                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
+            )
+
+    def process(self, llm_sentence):
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
+
+        pad_args = {}
+        if self.compile_mode:
+            # pad to closest upper power of two
+            pad_length = next_power_of_2(nb_tokens)
+            logger.debug(f"padding to {pad_length}")
+            pad_args["pad"] = True
+            pad_args["max_length_prompt"] = pad_length
+
+        tts_gen_kwargs = self.prepare_model_inputs(
+            llm_sentence,
+            **pad_args,
+        )
+
+        streamer = ParlerTTSStreamer(
+            self.model, device=self.device, play_steps=self.play_steps
+        )
+        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
+        torch.manual_seed(0)
+        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
+        thread.start()
+
+        for i, audio_chunk in enumerate(streamer):
+            # pipeline_start is set by the Whisper STT handler when inference begins;
+            # only log the latency when the timestamp is visible from this module
+            if i == 0 and "pipeline_start" in globals():
+                logger.info(
+                    f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
+                )
+            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
+            audio_chunk = (audio_chunk * 32768).astype(np.int16)
+            for j in range(0, len(audio_chunk), self.blocksize):
+                yield np.pad(
+                    audio_chunk[j : j + self.blocksize],
+                    (0, self.blocksize - len(audio_chunk[j : j + self.blocksize])),
+                )
+
+        self.should_listen.set()
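
The power-of-two padding is what keeps torch.compile tractable here: every prompt length is rounded up to one of at most max_prompt_pad_length - 2 bucket sizes, so only that many input shapes (per generation variant) ever reach the compiled model, which is why torch._dynamo.config.cache_size_limit above must be at least twice the number of buckets. A quick sketch of the bucketing, reusing next_power_of_2 from utils:

from utils.utils import next_power_of_2

max_prompt_pad_length = 8
pad_lengths = [2**i for i in range(2, max_prompt_pad_length)]
print(pad_lengths)  # [4, 8, 16, 32, 64, 128]

for nb_tokens in (3, 17, 100):
    # pad to the closest upper power of two so a compiled graph is reused
    print(nb_tokens, "->", next_power_of_2(nb_tokens))  # 4, 32, 128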
diff --git a/VAD/vad_handler.py b/VAD/vad_handler.py
new file mode 100644
index 0000000..e2da03f
--- /dev/null
+++ b/VAD/vad_handler.py
@@ -0,0 +1,64 @@
+from VAD.vad_iterator import VADIterator
+from baseHandler import BaseHandler
+import numpy as np
+import torch
+from rich.console import Console
+
+from utils.utils import int2float
+
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class VADHandler(BaseHandler):
+    """
+    Handles voice activity detection. When voice activity is detected, audio is accumulated until the end of
+    speech is detected, and the whole utterance is then passed on to the next stage.
+    """
+
+    def setup(
+        self,
+        should_listen,
+        thresh=0.3,
+        sample_rate=16000,
+        min_silence_ms=1000,
+        min_speech_ms=500,
+        max_speech_ms=float("inf"),
+        speech_pad_ms=30,
+    ):
+        self.should_listen = should_listen
+        self.sample_rate = sample_rate
+        self.min_silence_ms = min_silence_ms
+        self.min_speech_ms = min_speech_ms
+        self.max_speech_ms = max_speech_ms
+        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
+        self.iterator = VADIterator(
+            self.model,
+            threshold=thresh,
+            sampling_rate=sample_rate,
+            min_silence_duration_ms=min_silence_ms,
+            speech_pad_ms=speech_pad_ms,
+        )
+
+    def process(self, audio_chunk):
+        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
+        audio_float32 = int2float(audio_int16)
+        vad_output = self.iterator(torch.from_numpy(audio_float32))
+        if vad_output is not None and len(vad_output) != 0:
+            logger.debug("VAD: end of speech detected")
+            array = torch.cat(vad_output).cpu().numpy()
+            duration_ms = len(array) / self.sample_rate * 1000
+            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
+                logger.debug(
+                    f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
+                )
+            else:
+                self.should_listen.clear()
+                logger.debug("Stop listening")
+                yield array
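
The conversion at the top of process, shown standalone (a sketch with a synthetic chunk; in the pipeline the input is raw int16 bytes from the socket receiver or the local streamer):

import numpy as np
import torch
from utils.utils import int2float

chunk_size = 1024  # bytes, i.e. 512 int16 samples
audio_chunk = (np.random.randn(chunk_size // 2) * 1000).astype(np.int16).tobytes()

audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
audio_float32 = int2float(audio_int16)       # scaled into [-1, 1]
vad_input = torch.from_numpy(audio_float32)  # what the VAD iterator consumes
assert vad_input.abs().max() <= 1.0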
diff --git a/utils.py b/VAD/vad_iterator.py
similarity index 89%
rename from utils.py
rename to VAD/vad_iterator.py
index 3e2e9bc..bd272f1 100644
--- a/utils.py
+++ b/VAD/vad_iterator.py
@@ -1,24 +1,6 @@
-import numpy as np
 import torch
 
 
-def next_power_of_2(x):
-    return 1 if x == 0 else 2 ** (x - 1).bit_length()
-
-
-def int2float(sound):
-    """
-    Taken from https://github.com/snakers4/silero-vad
-    """
-
-    abs_max = np.abs(sound).max()
-    sound = sound.astype("float32")
-    if abs_max > 0:
-        sound *= 1 / 32768
-    sound = sound.squeeze()  # depends on the use case
-    return sound
-
-
 class VADIterator:
     def __init__(
         self,
diff --git a/local_audio_streamer.py b/connections/local_audio_streamer.py
similarity index 100%
rename from local_audio_streamer.py
rename to connections/local_audio_streamer.py
diff --git a/connections/socket_receiver.py b/connections/socket_receiver.py
new file mode 100644
index 0000000..19bda8d
--- /dev/null
+++ b/connections/socket_receiver.py
@@ -0,0 +1,63 @@
+import socket
+from rich.console import Console
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class SocketReceiver:
+    """
+    Handles receiving audio packets from the client.
+    """
+
+    def __init__(
+        self,
+        stop_event,
+        queue_out,
+        should_listen,
+        host="0.0.0.0",
+        port=12345,
+        chunk_size=1024,
+    ):
+        self.stop_event = stop_event
+        self.queue_out = queue_out
+        self.should_listen = should_listen
+        self.chunk_size = chunk_size
+        self.host = host
+        self.port = port
+
+    def receive_full_chunk(self, conn, chunk_size):
+        data = b""
+        while len(data) < chunk_size:
+            packet = conn.recv(chunk_size - len(data))
+            if not packet:
+                # connection closed
+                return None
+            data += packet
+        return data
+
+    def run(self):
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        self.socket.bind((self.host, self.port))
+        self.socket.listen(1)
+        logger.info("Receiver waiting to be connected...")
+        self.conn, _ = self.socket.accept()
+        logger.info("receiver connected")
+
+        self.should_listen.set()
+        while not self.stop_event.is_set():
+            audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
+            if audio_chunk is None:
+                # connection closed
+                self.queue_out.put(b"END")
+                break
+            if self.should_listen.is_set():
+                self.queue_out.put(audio_chunk)
+        self.conn.close()
+        logger.info("Receiver closed")
diff --git a/connections/socket_sender.py b/connections/socket_sender.py
new file mode 100644
index 0000000..587343b
--- /dev/null
+++ b/connections/socket_sender.py
@@ -0,0 +1,39 @@
+import socket
+from rich.console import Console
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class SocketSender:
+    """
+    Handles sending generated audio packets back to the client.
+    """
+
+    def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
+        self.stop_event = stop_event
+        self.queue_in = queue_in
+        self.host = host
+        self.port = port
+
+    def run(self):
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        self.socket.bind((self.host, self.port))
+        self.socket.listen(1)
+        logger.info("Sender waiting to be connected...")
+        self.conn, _ = self.socket.accept()
+        logger.info("sender connected")
+
+        while not self.stop_event.is_set():
+            audio_chunk = self.queue_in.get()
+            self.conn.sendall(audio_chunk)
+            if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
+                break
+        self.conn.close()
+        logger.info("Sender closed")
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 9ed41eb..7a9e204 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -1,43 +1,32 @@
 import logging
 import os
-import socket
 import sys
-import threading
 from copy import copy
 from pathlib import Path
 from queue import Queue
-from threading import Event, Thread
-from time import perf_counter
+from threading import Event
 from typing import Optional
 from sys import platform
+from VAD.vad_handler import VADHandler
 from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
-from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments
+from arguments_classes.mlx_language_model_arguments import (
+    MLXLanguageModelHandlerArguments,
+)
 from arguments_classes.module_arguments import ModuleArguments
 from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
 from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
 from arguments_classes.socket_sender_arguments import SocketSenderArguments
 from arguments_classes.vad_arguments import VADHandlerArguments
 from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
-from baseHandler import BaseHandler
 from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
-import numpy as np
 import torch
 import nltk
-from nltk.tokenize import sent_tokenize
 from rich.console import Console
 from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    AutoTokenizer,
     HfArgumentParser,
-    pipeline,
-    TextIteratorStreamer,
 )
-from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
-import librosa
 
-from utils import VADIterator, int2float, next_power_of_2
+from utils.thread_manager import ThreadManager
 
 # Ensure that the necessary NLTK resources are available
 try:
@@ -58,550 +47,6 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
 console = Console()
 
 
-class ThreadManager:
-    """
-    Manages multiple threads used to execute given handler tasks.
-    """
-
-    def __init__(self, handlers):
-        self.handlers = handlers
-        self.threads = []
-
-    def start(self):
-        for handler in self.handlers:
-            thread = threading.Thread(target=handler.run)
-            self.threads.append(thread)
-            thread.start()
-
-    def stop(self):
-        for handler in self.handlers:
-            handler.stop_event.set()
-        for thread in self.threads:
-            thread.join()
-
-
-class SocketReceiver:
-    """
-    Handles reception of the audio packets from the client.
-    """
-
-    def __init__(
-        self,
-        stop_event,
-        queue_out,
-        should_listen,
-        host="0.0.0.0",
-        port=12345,
-        chunk_size=1024,
-    ):
-        self.stop_event = stop_event
-        self.queue_out = queue_out
-        self.should_listen = should_listen
-        self.chunk_size = chunk_size
-        self.host = host
-        self.port = port
-
-    def receive_full_chunk(self, conn, chunk_size):
-        data = b""
-        while len(data) < chunk_size:
-            packet = conn.recv(chunk_size - len(data))
-            if not packet:
-                # connection closed
-                return None
-            data += packet
-        return data
-
-    def run(self):
-        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.socket.bind((self.host, self.port))
-        self.socket.listen(1)
-        logger.info("Receiver waiting to be connected...")
-        self.conn, _ = self.socket.accept()
-        logger.info("receiver connected")
-
-        self.should_listen.set()
-        while not self.stop_event.is_set():
-            audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
-            if audio_chunk is None:
-                # connection closed
-                self.queue_out.put(b"END")
-                break
-            if self.should_listen.is_set():
-                self.queue_out.put(audio_chunk)
-        self.conn.close()
-        logger.info("Receiver closed")
-
-
-class SocketSender:
-    """
-    Handles sending generated audio packets to the clients.
-    """
-
-    def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
-        self.stop_event = stop_event
-        self.queue_in = queue_in
-        self.host = host
-        self.port = port
-
-    def run(self):
-        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.socket.bind((self.host, self.port))
-        self.socket.listen(1)
-        logger.info("Sender waiting to be connected...")
-        self.conn, _ = self.socket.accept()
-        logger.info("sender connected")
-
-        while not self.stop_event.is_set():
-            audio_chunk = self.queue_in.get()
-            self.conn.sendall(audio_chunk)
-            if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
-                break
-        self.conn.close()
-        logger.info("Sender closed")
-
-
-class VADHandler(BaseHandler):
-    """
-    Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
-    to the following part.
-    """
-
-    def setup(
-        self,
-        should_listen,
-        thresh=0.3,
-        sample_rate=16000,
-        min_silence_ms=1000,
-        min_speech_ms=500,
-        max_speech_ms=float("inf"),
-        speech_pad_ms=30,
-    ):
-        self.should_listen = should_listen
-        self.sample_rate = sample_rate
-        self.min_silence_ms = min_silence_ms
-        self.min_speech_ms = min_speech_ms
-        self.max_speech_ms = max_speech_ms
-        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
-        self.iterator = VADIterator(
-            self.model,
-            threshold=thresh,
-            sampling_rate=sample_rate,
-            min_silence_duration_ms=min_silence_ms,
-            speech_pad_ms=speech_pad_ms,
-        )
-
-    def process(self, audio_chunk):
-        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
-        audio_float32 = int2float(audio_int16)
-        vad_output = self.iterator(torch.from_numpy(audio_float32))
-        if vad_output is not None and len(vad_output) != 0:
-            logger.debug("VAD: end of speech detected")
-            array = torch.cat(vad_output).cpu().numpy()
-            duration_ms = len(array) / self.sample_rate * 1000
-            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
-                logger.debug(
-                    f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
-                )
-            else:
-                self.should_listen.clear()
-                logger.debug("Stop listening")
-                yield array
-
-
-class WhisperSTTHandler(BaseHandler):
-    """
-    Handles the Speech To Text generation using a Whisper model.
-    """
-
-    def setup(
-        self,
-        model_name="distil-whisper/distil-large-v3",
-        device="cuda",
-        torch_dtype="float16",
-        compile_mode=None,
-        gen_kwargs={},
-    ):
-        self.device = device
-        self.torch_dtype = getattr(torch, torch_dtype)
-        self.compile_mode = compile_mode
-        self.gen_kwargs = gen_kwargs
-
-        self.processor = AutoProcessor.from_pretrained(model_name)
-        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            model_name,
-            torch_dtype=self.torch_dtype,
-        ).to(device)
-
-        # compile
-        if self.compile_mode:
-            self.model.generation_config.cache_implementation = "static"
-            self.model.forward = torch.compile(
-                self.model.forward, mode=self.compile_mode, fullgraph=True
-            )
-        self.warmup()
-
-    def prepare_model_inputs(self, spoken_prompt):
-        input_features = self.processor(
-            spoken_prompt, sampling_rate=16000, return_tensors="pt"
-        ).input_features
-        input_features = input_features.to(self.device, dtype=self.torch_dtype)
-
-        return input_features
-
-    def warmup(self):
-        logger.info(f"Warming up {self.__class__.__name__}")
-
-        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
-        n_steps = 1 if self.compile_mode == "default" else 2
-        dummy_input = torch.randn(
-            (1, self.model.config.num_mel_bins, 3000),
-            dtype=self.torch_dtype,
-            device=self.device,
-        )
-        if self.compile_mode not in (None, "default"):
-            # generating more tokens than previously will trigger CUDA graphs capture
-            # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
-            warmup_gen_kwargs = {
-                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
-                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
-                **self.gen_kwargs,
-            }
-        else:
-            warmup_gen_kwargs = self.gen_kwargs
-
-        if self.device == "cuda":
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-            torch.cuda.synchronize()
-            start_event.record()
-
-        for _ in range(n_steps):
-            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
-
-        if self.device == "cuda":
-            end_event.record()
-            torch.cuda.synchronize()
-
-            logger.info(
-                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
-            )
-
-    def process(self, spoken_prompt):
-        logger.debug("infering whisper...")
-
-        global pipeline_start
-        pipeline_start = perf_counter()
-
-        input_features = self.prepare_model_inputs(spoken_prompt)
-        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
-        pred_text = self.processor.batch_decode(
-            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
-        )[0]
-
-        logger.debug("finished whisper inference")
-        console.print(f"[yellow]USER: {pred_text}")
-
-        yield pred_text
-
-
-class Chat:
-    """
-    Handles the chat using to avoid OOM issues.
-    """
-
-    def __init__(self, size):
-        self.size = size
-        self.init_chat_message = None
-        # maxlen is necessary pair, since a each new step we add an prompt and assitant answer
-        self.buffer = []
-
-    def append(self, item):
-        self.buffer.append(item)
-        if len(self.buffer) == 2 * (self.size + 1):
-            self.buffer.pop(0)
-            self.buffer.pop(0)
-
-    def init_chat(self, init_chat_message):
-        self.init_chat_message = init_chat_message
-
-    def to_list(self):
-        if self.init_chat_message:
-            return [self.init_chat_message] + self.buffer
-        else:
-            return self.buffer
-
-
-class LanguageModelHandler(BaseHandler):
-    """
-    Handles the language model part.
-    """
-
-    def setup(
-        self,
-        model_name="microsoft/Phi-3-mini-4k-instruct",
-        device="cuda",
-        torch_dtype="float16",
-        gen_kwargs={},
-        user_role="user",
-        chat_size=1,
-        init_chat_role=None,
-        init_chat_prompt="You are a helpful AI assistant.",
-    ):
-        self.device = device
-        self.torch_dtype = getattr(torch, torch_dtype)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch_dtype, trust_remote_code=True
-        ).to(device)
-        self.pipe = pipeline(
-            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
-        )
-        self.streamer = TextIteratorStreamer(
-            self.tokenizer,
-            skip_prompt=True,
-            skip_special_tokens=True,
-        )
-        self.gen_kwargs = {
-            "streamer": self.streamer,
-            "return_full_text": False,
-            **gen_kwargs,
-        }
-
-        self.chat = Chat(chat_size)
-        if init_chat_role:
-            if not init_chat_prompt:
-                raise ValueError(
-                    "An initial promt needs to be specified when setting init_chat_role."
-                )
-            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
-        self.user_role = user_role
-
-        self.warmup()
-
-    def warmup(self):
-        logger.info(f"Warming up {self.__class__.__name__}")
-
-        dummy_input_text = "Write me a poem about Machine Learning."
-        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
-        warmup_gen_kwargs = {
-            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
-            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
-            **self.gen_kwargs,
-        }
-
-        n_steps = 2
-
-        if self.device == "cuda":
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-            torch.cuda.synchronize()
-            start_event.record()
-
-        for _ in range(n_steps):
-            thread = Thread(
-                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
-            )
-            thread.start()
-            for _ in self.streamer:
-                pass
-
-        if self.device == "cuda":
-            end_event.record()
-            torch.cuda.synchronize()
-
-            logger.info(
-                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
-            )
-
-    def process(self, prompt):
-        logger.debug("infering language model...")
-
-        self.chat.append({"role": self.user_role, "content": prompt})
-        thread = Thread(
-            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
-        )
-        thread.start()
-        if self.device == "mps":
-            generated_text = ""
-            for new_text in self.streamer:
-                generated_text += new_text
-            printable_text = generated_text
-            torch.mps.empty_cache()
-        else:
-            generated_text, printable_text = "", ""
-            for new_text in self.streamer:
-                generated_text += new_text
-                printable_text += new_text
-                sentences = sent_tokenize(printable_text)
-                if len(sentences) > 1:
-                    yield (sentences[0])
-                    printable_text = new_text
-
-        self.chat.append({"role": "assistant", "content": generated_text})
-
-        # don't forget last sentence
-        yield printable_text
-
-
-class ParlerTTSHandler(BaseHandler):
-    def setup(
-        self,
-        should_listen,
-        model_name="ylacombe/parler-tts-mini-jenny-30H",
-        device="cuda",
-        torch_dtype="float16",
-        compile_mode=None,
-        gen_kwargs={},
-        max_prompt_pad_length=8,
-        description=(
-            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
-            "She speaks very fast."
-        ),
-        play_steps_s=1,
-        blocksize=512,
-    ):
-        self.should_listen = should_listen
-        self.device = device
-        self.torch_dtype = getattr(torch, torch_dtype)
-        self.gen_kwargs = gen_kwargs
-        self.compile_mode = compile_mode
-        self.max_prompt_pad_length = max_prompt_pad_length
-        self.description = description
-
-        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=self.torch_dtype
-        ).to(device)
-
-        framerate = self.model.audio_encoder.config.frame_rate
-        self.play_steps = int(framerate * play_steps_s)
-        self.blocksize = blocksize
-
-        if self.compile_mode not in (None, "default"):
-            logger.warning(
-                "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
-            )
-            self.compile_mode = "default"
-
-        if self.compile_mode:
-            self.model.generation_config.cache_implementation = "static"
-            self.model.forward = torch.compile(
-                self.model.forward, mode=self.compile_mode, fullgraph=True
-            )
-
-        self.warmup()
-
-    def prepare_model_inputs(
-        self,
-        prompt,
-        max_length_prompt=50,
-        pad=False,
-    ):
-        pad_args_prompt = (
-            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
-        )
-
-        tokenized_description = self.description_tokenizer(
-            self.description, return_tensors="pt"
-        )
-        input_ids = tokenized_description.input_ids.to(self.device)
-        attention_mask = tokenized_description.attention_mask.to(self.device)
-
-        tokenized_prompt = self.prompt_tokenizer(
-            prompt, return_tensors="pt", **pad_args_prompt
-        )
-        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
-        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
-
-        gen_kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "prompt_input_ids": prompt_input_ids,
-            "prompt_attention_mask": prompt_attention_mask,
-            **self.gen_kwargs,
-        }
-
-        return gen_kwargs
-
-    def warmup(self):
-        logger.info(f"Warming up {self.__class__.__name__}")
-
-        if self.device == "cuda":
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-
-        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
-        n_steps = 1 if self.compile_mode == "default" else 2
-
-        if self.device == "cuda":
-            torch.cuda.synchronize()
-            start_event.record()
-        if self.compile_mode:
-            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
-            for pad_length in pad_lengths[::-1]:
-                model_kwargs = self.prepare_model_inputs(
-                    "dummy prompt", max_length_prompt=pad_length, pad=True
-                )
-                for _ in range(n_steps):
-                    _ = self.model.generate(**model_kwargs)
-                logger.info(f"Warmed up length {pad_length} tokens!")
-        else:
-            model_kwargs = self.prepare_model_inputs("dummy prompt")
-            for _ in range(n_steps):
-                _ = self.model.generate(**model_kwargs)
-
-        if self.device == "cuda":
-            end_event.record()
-            torch.cuda.synchronize()
-            logger.info(
-                f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
-            )
-
-    def process(self, llm_sentence):
-        console.print(f"[green]ASSISTANT: {llm_sentence}")
-        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
-
-        pad_args = {}
-        if self.compile_mode:
-            # pad to closest upper power of two
-            pad_length = next_power_of_2(nb_tokens)
-            logger.debug(f"padding to {pad_length}")
-            pad_args["pad"] = True
-            pad_args["max_length_prompt"] = pad_length
-
-        tts_gen_kwargs = self.prepare_model_inputs(
-            llm_sentence,
-            **pad_args,
-        )
-
-        streamer = ParlerTTSStreamer(
-            self.model, device=self.device, play_steps=self.play_steps
-        )
-        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
-        torch.manual_seed(0)
-        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
-        thread.start()
-
-        for i, audio_chunk in enumerate(streamer):
-            if i == 0 and "pipeline_start" in globals():
-                logger.info(
-                    f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
-                )
-            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
-            audio_chunk = (audio_chunk * 32768).astype(np.int16)
-            for i in range(0, len(audio_chunk), self.blocksize):
-                yield np.pad(
-                    audio_chunk[i : i + self.blocksize],
-                    (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
-                )
-
-        self.should_listen.set()
-
-
 def prepare_args(args, prefix):
     """
     Rename arguments by removing the prefix and prepares the gen_kwargs.
@@ -745,7 +190,7 @@ def main():
     lm_response_queue = Queue()
 
     if module_kwargs.mode == "local":
-        from local_audio_streamer import LocalAudioStreamer
+        from connections.local_audio_streamer import LocalAudioStreamer
 
         local_audio_streamer = LocalAudioStreamer(
             input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
@@ -753,6 +198,9 @@ def main():
         comms_handlers = [local_audio_streamer]
         should_listen.set()
     else:
+        from connections.socket_receiver import SocketReceiver
+        from connections.socket_sender import SocketSender
+
         comms_handlers = [
             SocketReceiver(
                 stop_event,
@@ -778,6 +226,8 @@ def main():
         setup_kwargs=vars(vad_handler_kwargs),
     )
     if module_kwargs.stt == "whisper":
+        from STT.whisper_stt_handler import WhisperSTTHandler
+
         stt = WhisperSTTHandler(
             stop_event,
             queue_in=spoken_prompt_queue,
@@ -786,6 +236,7 @@ def main():
         )
     elif module_kwargs.stt == "whisper-mlx":
         from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
+
         stt = LightningWhisperSTTHandler(
             stop_event,
             queue_in=spoken_prompt_queue,
@@ -795,6 +246,8 @@ def main():
     else:
         raise ValueError("The STT should be either whisper or whisper-mlx")
     if module_kwargs.llm == "transformers":
+        from LLM.language_model import LanguageModelHandler
+
         lm = LanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
@@ -802,7 +255,8 @@ def main():
             setup_kwargs=vars(language_model_handler_kwargs),
         )
     elif module_kwargs.llm == "mlx-lm":
-        from LLM.mlx_lm import MLXLanguageModelHandler
+        from LLM.mlx_language_model import MLXLanguageModelHandler
+
         lm = MLXLanguageModelHandler(
             stop_event,
             queue_in=text_prompt_queue,
@@ -812,9 +266,8 @@ def main():
     else:
         raise ValueError("The LLM should be either transformers or mlx-lm")
     if module_kwargs.tts == "parler":
-        torch._inductor.config.fx_graph_cache = True
-        # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
-        torch._dynamo.config.cache_size_limit = 15
+        from TTS.parler_handler import ParlerTTSHandler
+
         tts = ParlerTTSHandler(
             stop_event,
             queue_in=lm_response_queue,
@@ -825,7 +278,7 @@ def main():
 
     elif module_kwargs.tts == "melo":
         try:
-            from TTS.melotts import MeloTTSHandler
+            from TTS.melo_handler import MeloTTSHandler
         except RuntimeError as e:
             logger.error(
                 "Error importing MeloTTSHandler. You might need to run: python -m unidic download"
diff --git a/utils/thread_manager.py b/utils/thread_manager.py
new file mode 100644
index 0000000..fc1ca4a
--- /dev/null
+++ b/utils/thread_manager.py
@@ -0,0 +1,23 @@
+import threading
+
+
+class ThreadManager:
+    """
+    Manages multiple threads used to execute given handler tasks.
+    """
+
+    def __init__(self, handlers):
+        self.handlers = handlers
+        self.threads = []
+
+    def start(self):
+        for handler in self.handlers:
+            thread = threading.Thread(target=handler.run)
+            self.threads.append(thread)
+            thread.start()
+
+    def stop(self):
+        for handler in self.handlers:
+            handler.stop_event.set()
+        for thread in self.threads:
+            thread.join()
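
Usage sketch for ThreadManager: any object exposing a run method and a stop_event works as a handler (the Sleeper class here is illustrative):

import threading
import time

from utils.thread_manager import ThreadManager

class Sleeper:
    def __init__(self):
        self.stop_event = threading.Event()

    def run(self):
        while not self.stop_event.is_set():
            time.sleep(0.1)

manager = ThreadManager([Sleeper(), Sleeper()])
manager.start()
time.sleep(0.5)
manager.stop()  # sets each stop_event, then joins the threads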
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..ac39948
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+
+def next_power_of_2(x):
+    return 1 if x == 0 else 2 ** (x - 1).bit_length()
+
+
+def int2float(sound):
+    """
+    Taken from https://github.com/snakers4/silero-vad
+    """
+
+    abs_max = np.abs(sound).max()
+    sound = sound.astype("float32")
+    if abs_max > 0:
+        sound *= 1 / 32768
+    sound = sound.squeeze()  # depends on the use case
+    return sound
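
Quick checks of the two helpers (the expected values follow directly from the definitions above):

import numpy as np

from utils.utils import int2float, next_power_of_2

assert [next_power_of_2(x) for x in (0, 1, 5, 64, 65)] == [1, 1, 8, 64, 128]

samples = np.array([0, 16384, -32767], dtype=np.int16)
print(int2float(samples))  # ~[0.0, 0.5, -1.0]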
-- 
GitLab