diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 1b744ff85d3bd9e768c16cce5b236a949025aefb..f7af261c0afde3bd187eaaec48e4680f0badc8d0 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -1,38 +1,39 @@
 import logging
+import os
 import socket
+import sys
 import threading
-from threading import Thread, Event
+from collections import deque
+from copy import copy
+from dataclasses import dataclass, field
+from pathlib import Path
 from queue import Queue
+from threading import Event, Thread
 from time import perf_counter
-import sys
-import os
-from pathlib import Path
-from dataclasses import dataclass, field
-from copy import copy
-from collections import deque
 
 import numpy as np
 import torch
 from nltk.tokenize import sent_tokenize
 from rich.console import Console
 from transformers import (
-    AutoModelForCausalLM, 
-    AutoModelForSpeechSeq2Seq, 
-    AutoProcessor, 
-    AutoTokenizer, 
-    pipeline, 
-    TextIteratorStreamer,
-    HfArgumentParser
+    AutoModelForCausalLM,
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+    AutoTokenizer,
+    HfArgumentParser,
+    pipeline,
+    TextIteratorStreamer
 )
+
 from parler_tts import (
     ParlerTTSForConditionalGeneration,
-    ParlerTTSStreamer,
+    ParlerTTSStreamer
 )
 
 from utils import (
-    VADIterator, 
+    VADIterator,
     int2float,
-    next_power_of_2,
+    next_power_of_2
 )
 
 
@@ -44,8 +45,10 @@ torch._inductor.config.fx_graph_cache = True
 # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
 torch._dynamo.config.cache_size_limit = 15
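+# e.g. with the default max_prompt_pad_length=8, the TTS warmup below compiles prompt
+# pad lengths 2**2 .. 2**7 (6 sizes), so 2 * 6 = 12 graphs stay within this limit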
 
+
 console = Console()
 
+
 @dataclass
 class ModuleArguments:
     log_level: str = field(
@@ -55,7 +58,12 @@ class ModuleArguments:
         }
     )
 
+
 class ThreadManager:
+    """
+    Manages multiple threads used to execute given handler tasks.
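+
+    Illustrative usage (assuming each handler exposes a `run` method):
+
+        manager = ThreadManager([vad, stt, lm, tts])
+        manager.start()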
+    """
+
     def __init__(self, handlers):
         self.handlers = handlers
         self.threads = []
@@ -72,7 +80,16 @@ class ThreadManager:
         for thread in self.threads:
             thread.join()
 
+
 class BaseHandler:
+    """
+    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
+    Objects placed in the input queue are processed by the `process` method, and the yielded results are placed in the output queue.
+    The `setup` method, along with `setup_args` and `setup_kwargs`, can be used to address the specific requirements of the implemented pipeline part.
+    To stop a handler properly, set the `stop_event` and, to avoid queue deadlocks, place b"END" in the input queue.
+    The `cleanup` method then stops the handler and places b"END" in the output queue.
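+
+    Minimal subclass sketch (illustrative only):
+
+        class EchoHandler(BaseHandler):
+            def setup(self):
+                pass
+
+            def process(self, item):
+                yield item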
+    """
+
     def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
         self.stop_event = stop_event
         self.queue_in = queue_in
@@ -135,6 +152,10 @@ class SocketReceiverArguments:
 
 
 class SocketReceiver:
+    """
+    Handles the reception of audio packets from the client.
+    """
+
     def __init__(
         self, 
         stop_event,
@@ -201,6 +222,10 @@ class SocketSenderArguments:
 
             
 class SocketSender:
+    """
+    Handles sending the generated audio packets to the client.
+    """
+
     def __init__(
         self, 
         stop_event,
@@ -273,6 +298,11 @@ class VADHandlerArguments:
 
 
 class VADHandler(BaseHandler):
+    """
+    Handles voice activity detection. When voice activity is detected, audio is accumulated until the end of speech is detected,
+    and the full utterance is then passed on to the next part of the pipeline.
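+    Utterances shorter than min_speech_ms or longer than max_speech_ms are discarded.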
+    """
+
     def setup(
             self, 
             should_listen,
@@ -284,11 +314,11 @@ class VADHandler(BaseHandler):
             speech_pad_ms=30,
 
         ):
-        self._should_listen = should_listen
-        self._sample_rate = sample_rate
-        self._min_silence_ms = min_silence_ms
-        self._min_speech_ms = min_speech_ms
-        self._max_speech_ms = max_speech_ms
+        self.should_listen = should_listen
+        self.sample_rate = sample_rate
+        self.min_silence_ms = min_silence_ms
+        self.min_speech_ms = min_speech_ms
+        self.max_speech_ms = max_speech_ms
         self.model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad')
         self.iterator = VADIterator(
             self.model,
@@ -305,8 +335,8 @@ class VADHandler(BaseHandler):
         if vad_output is not None and len(vad_output) != 0:
             logger.debug("VAD: end of speech detected")
             array = torch.cat(vad_output).cpu().numpy()
-            duration_ms = len(array) / self._sample_rate * 1000
-            if duration_ms < self._min_speech_ms or duration_ms > self._max_speech_ms:
+            duration_ms = len(array) / self.sample_rate * 1000
+            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                 logger.debug(f"audio input of duration: {len(array) / self._sample_rate}s, skipping")
             else:
                 self._should_listen.clear()
@@ -373,6 +403,10 @@ class WhisperSTTHandlerArguments:
 
 
 class WhisperSTTHandler(BaseHandler):
+    """
+    Handles the Speech To Text generation using a Whisper model.
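+    Takes the detected speech segment as input, feeds it to the processor at a 16 kHz sampling rate, and yields the transcribed text.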
+    """
+
     def setup(
             self,
             model_name="distil-whisper/distil-large-v3",
@@ -381,16 +415,17 @@ class WhisperSTTHandler(BaseHandler):
             compile_mode=None,
             gen_kwargs={}
         ): 
-        self.compile_mode=compile_mode
-        self.processor = AutoProcessor.from_pretrained(model_name)
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
+        self.compile_mode = compile_mode
+        self.gen_kwargs = gen_kwargs
+
+        self.processor = AutoProcessor.from_pretrained(model_name)
         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
             model_name,
             torch_dtype=self.torch_dtype,
         ).to(device)
-        self.gen_kwargs = gen_kwargs
-
+
         # compile
         if self.compile_mode:
             self.model.generation_config.cache_implementation = "static"
@@ -402,20 +437,19 @@ class WhisperSTTHandler(BaseHandler):
             spoken_prompt, sampling_rate=16000, return_tensors="pt"
         ).input_features
         input_features = input_features.to(self.device, dtype=self.torch_dtype)
+
         return input_features
         
     def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
         # 2 warmup steps for no compile or compile mode with CUDA graphs capture 
         n_steps = 1 if self.compile_mode == "default" else 2
-        logger.info(f"Warming up {self.__class__.__name__}")
         dummy_input = torch.randn(
             (1,  self.model.config.num_mel_bins, 3000),
             dtype=self.torch_dtype,
             device=self.device
         ) 
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        torch.cuda.synchronize()
         if self.compile_mode not in (None, "default"):
             # generating more tokens than previously will trigger CUDA graphs capture
             # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
@@ -427,28 +461,35 @@ class WhisperSTTHandler(BaseHandler):
         else:
             warmup_gen_kwargs = self.gen_kwargs
 
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+
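+        # CUDA events record on the GPU stream; synchronizing before and after makes
+        # the measured time cover exactly the warmup generate calls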
+        torch.cuda.synchronize()
         start_event.record()
         for _ in range(n_steps):
             _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
         end_event.record()
         torch.cuda.synchronize()
+
         logger.info(f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
 
     def process(self, spoken_prompt):
+        logger.debug("infering whisper...")
+
         global pipeline_start
         pipeline_start = perf_counter()
-        input_features = self.processor(
-            spoken_prompt, sampling_rate=16000, return_tensors="pt"
-        ).input_features
-        input_features = input_features.to(self.device, dtype=self.torch_dtype)
-        logger.debug("infering whisper...")
+
+        input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
         pred_text = self.processor.batch_decode(
-            pred_ids, skip_special_tokens=True,
+            pred_ids,
+            skip_special_tokens=True,
             decode_with_timestamps=False
         )[0]
+
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+
         yield pred_text
 
 
@@ -509,6 +550,10 @@ class LanguageModelHandlerArguments:
 
 
 class Chat:
+    """
+    Handles the chat using a circular buffer to avoid OOM issues.
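+
+    Illustrative behaviour of the underlying `deque(maxlen=size)`:
+
+        chat = Chat(2)
+        chat.append({"role": "user", "content": "A"})
+        chat.append({"role": "assistant", "content": "B"})
+        chat.append({"role": "user", "content": "C"})  # evicts "A"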
+    """
+
     def __init__(self, size):
         self.init_chat_message = None
         self.buffer = deque(maxlen=size)
@@ -527,25 +572,30 @@ class Chat:
 
 
 class LanguageModelHandler(BaseHandler):
+    """
+    Handles the text generation using a language model.
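+    The generated reply is streamed and yielded sentence by sentence, so that the TTS can start synthesizing as soon as the first sentence is complete.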
+    """
+
     def setup(
             self,
             model_name="microsoft/Phi-3-mini-4k-instruct",
             device="cuda", 
             torch_dtype="float16",
-            chat_size=3,
             gen_kwargs={},
             user_role="user",
+            chat_size=3,
             init_chat_role=None, 
             init_chat_prompt="You are a helpful AI assistant.",
         ):
+        self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            torch_dtype=torch_dtype,
+            torch_dtype=self.torch_dtype,
             trust_remote_code=True
         ).to(device)
-        self.device = device
         self.pipe = pipeline( 
             "text-generation", 
             model=self.model, 
@@ -556,6 +606,12 @@ class LanguageModelHandler(BaseHandler):
             skip_prompt=True,
             skip_special_tokens=True,
         )
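+        # return_full_text=False makes the pipeline return only the newly generated
+        # text, which the streamer above then yields chunk by chunk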
+        self.gen_kwargs = {
+            "streamer": self.streamer,
+            "return_full_text": False,
+            **gen_kwargs
+        }
+
         self.chat = Chat(chat_size)
         if init_chat_role:
             if not init_chat_prompt:
@@ -563,26 +619,12 @@ class LanguageModelHandler(BaseHandler):
             self.chat.init_chat(
                 {"role": init_chat_role, "content": init_chat_prompt}
             )
-
-        self.gen_kwargs = {
-            "streamer": self.streamer,
-            "return_full_text": False,
-            **gen_kwargs
-        }
         self.user_role = user_role
 
-        
-
-
         self.warmup()
 
     def warmup(self):
-        # 2 warmup steps for no compile or compile mode with CUDA graphs capture 
-        n_steps = 2
         logger.info(f"Warming up {self.__class__.__name__}")
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        torch.cuda.synchronize()
 
         dummy_input_text = "Write me a poem about Machine Learning."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
@@ -592,25 +634,33 @@ class LanguageModelHandler(BaseHandler):
             **self.gen_kwargs
         }
 
+        n_steps = 2
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+
+        torch.cuda.synchronize()
         start_event.record()
         for _ in range(n_steps):
             thread = Thread(target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs)
             thread.start()
             for _ in self.streamer: 
-                pass
-                
+                pass
         end_event.record()
         torch.cuda.synchronize()
+
         logger.info(f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
 
     def process(self, prompt):
+        logger.debug("infering language model...")
+
         self.chat.append(
             {"role": self.user_role, "content": prompt}
         )
         thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs)
         thread.start()
+
         generated_text, printable_text = "", ""
-        logger.debug("infering language model...")
         for new_text in self.streamer:
             generated_text += new_text
             printable_text += new_text
@@ -618,9 +668,11 @@ class LanguageModelHandler(BaseHandler):
             if len(sentences) > 1:
                 yield(sentences[0])
                 printable_text = new_text
+
         self.chat.append(
             {"role": "assistant", "content": generated_text}
         )
+
         # don't forget last sentence
         yield printable_text
 
@@ -689,37 +741,33 @@ class ParlerTTSHandler(BaseHandler):
             model_name="ylacombe/parler-tts-mini-jenny-30H",
             device="cuda", 
             torch_dtype="float16",
-            max_prompt_pad_length=8,
-            gen_kwargs={},
             compile_mode=None,
+            gen_kwargs={},
+            max_prompt_pad_length=8,
             description=(
                 "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
                 "She speaks very fast."
             ),
             play_steps_s=1
         ):
+        self.should_listen = should_listen
+        self.device = device
+        self.torch_dtype = getattr(torch, torch_dtype)
+        self.gen_kwargs = gen_kwargs
+        self.compile_mode = compile_mode
         self.max_prompt_pad_length = max_prompt_pad_length
-        torch_dtype = getattr(torch, torch_dtype)
-        self._should_listen = should_listen
+        self.description = description
+
         self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) 
         self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = ParlerTTSForConditionalGeneration.from_pretrained(
             model_name,
-            torch_dtype=torch_dtype
+            torch_dtype=self.torch_dtype
         ).to(device)
-        self.device = device
-        self.torch_dtype = torch_dtype
-
-        self.description = description
-        self.gen_kwargs = gen_kwargs
-        
-        framerate = self.model.audio_encoder.config.frame_rate
-        self.play_steps = int(framerate * play_steps_s)
         
         framerate = self.model.audio_encoder.config.frame_rate
         self.play_steps = int(framerate * play_steps_s)
 
-        self.compile_mode = compile_mode
         if self.compile_mode not in (None, "default"):
             logger.warning("Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'")
             self.compile_mode = "default"
@@ -727,6 +775,7 @@ class ParlerTTSHandler(BaseHandler):
         if self.compile_mode:
             self.model.generation_config.cache_implementation = "static"
             self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)
+
         self.warmup()
 
     def prepare_model_inputs(
@@ -752,29 +801,40 @@ class ParlerTTSHandler(BaseHandler):
             "prompt_attention_mask": prompt_attention_mask,
             **self.gen_kwargs
         }
+
         return gen_kwargs
     
     def warmup(self):
-        pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
-        for pad_length in pad_lengths[::-1]:
-            model_kwargs = self.prepare_model_inputs(
-                "dummy prompt", 
-                max_length_prompt=pad_length,
-                pad=True
-            )
-            # 2 warmup steps for modes that capture CUDA graphs
-            n_steps = 1 if self.compile_mode == "default" else 2
-
-            logger.info(f"Warming up length {pad_length} tokens...")
-            start_event = torch.cuda.Event(enable_timing=True)
-            end_event = torch.cuda.Event(enable_timing=True)
-            torch.cuda.synchronize()
-            start_event.record()
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+
+        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
+        n_steps = 1 if self.compile_mode == "default" else 2
+
+        torch.cuda.synchronize()
+        start_event.record()
+        if self.compile_mode:
+            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
+            for pad_length in pad_lengths[::-1]:
+                model_kwargs = self.prepare_model_inputs(
+                    "dummy prompt", 
+                    max_length_prompt=pad_length,
+                    pad=True
+                )
+                for _ in range(n_steps):
+                    _ = self.model.generate(**model_kwargs)
+                logger.info(f"Warmed up length {pad_length} tokens!")
+        else:
+            model_kwargs = self.prepare_model_inputs("dummy prompt")
             for _ in range(n_steps):
-                _ = self.model.generate(**model_kwargs)
-            end_event.record()
-            torch.cuda.synchronize()
-            logger.info(f"Warmed up! Compilation time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
+                _ = self.model.generate(**model_kwargs)
+
+        end_event.record()
+        torch.cuda.synchronize()
+        logger.info(f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
 
     def process(self, llm_sentence):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
@@ -808,10 +868,14 @@ class ParlerTTSHandler(BaseHandler):
             audio_chunk = np.int16(audio_chunk * 32767)
             yield audio_chunk
 
-        self._should_listen.set()
+        self.should_listen.set()
 
 
 def prepare_args(args, prefix):
+    """
+    Renames arguments by removing the prefix and prepares the `gen_kwargs`.
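+    For example, with the prefix "stt", `stt_model_name` is renamed to `model_name`, and generation arguments such as `stt_gen_max_new_tokens` are gathered into `gen_kwargs` as `max_new_tokens`.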
+    """
+
     gen_kwargs = {}
     for key in copy(args.__dict__):
         if key.startswith(prefix):
@@ -860,6 +924,7 @@ def main():
             parler_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()
 
+    # 1. Set up logging
     global logger
     logging.basicConfig(
         level=module_kwargs.log_level.upper(),
@@ -871,12 +936,15 @@ def main():
     if module_kwargs.log_level == "debug":
         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
 
+    # 2. Prepare each part's arguments
     prepare_args(whisper_stt_handler_kwargs, "stt")
     prepare_args(language_model_handler_kwargs, "lm")
     prepare_args(parler_tts_handler_kwargs, "tts") 
 
+    # 3. Build the pipeline
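+    # audio flow: client -> recv_audio_chunks_queue -> VAD -> spoken_prompt_queue
+    # -> STT -> LM -> TTS -> send_audio_chunks_queue -> client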
     stop_event = Event()
-    should_listen = Event()
+    # used to stop putting received audio chunks in the queue until all sentences have been processed by the TTS
+    should_listen = Event()
     recv_audio_chunks_queue = Queue()
     send_audio_chunks_queue = Queue()
     spoken_prompt_queue = Queue() 
@@ -926,6 +994,7 @@ def main():
         port=socket_sender_kwargs.send_port,
         )
 
+    # 4. Run the pipeline
     try:
         pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler])
         pipeline_manager.start()