import logging
import socket
import threading
from threading import Thread, Event
from queue import Queue
from time import perf_counter
import sys
import os
from dataclasses import dataclass, field
from copy import copy
import multiprocessing
import numpy as np
import soundfile as sf
import torch
from nltk.tokenize import sent_tokenize
from rich.console import Console
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
    TextIteratorStreamer,
    HfArgumentParser
)
from parler_tts import (
    ParlerTTSForConditionalGeneration,
    ParlerTTSStreamer,
)
from utils import (
    VADIterator,
    int2float,
)

console = Console()


@dataclass
class ModuleArguments:
    log_level: str = field(
        default="info",
        metadata={
            "help": "Provide logging level. Example --log_level debug, default=info."
        }
    )


class ThreadManager:
    def __init__(self, handlers):
        self.handlers = handlers
        self.threads = []

    def start(self):
        for handler in self.handlers:
            thread = threading.Thread(target=handler.run)
            self.threads.append(thread)
            thread.start()

    def stop(self):
        for handler in self.handlers:
            handler.stop_event.set()
        for thread in self.threads:
            thread.join()


pipeline_start = None


class BaseHandler:
    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
        self.stop_event = stop_event
        self.queue_in = queue_in
        self.queue_out = queue_out
        self.setup(*setup_args, **setup_kwargs)
        self._times = []

    def setup(self):
        pass

    def process(self):
        raise NotImplementedError

    def run(self):
        while not self.stop_event.is_set():
            input = self.queue_in.get()
            if isinstance(input, bytes) and input == b'END':
                # sentinel signal to avoid queue deadlock
                logger.debug("Stopping thread")
                break
            start_time = perf_counter()
            for output in self.process(input):
                self._times.append(perf_counter() - start_time)
                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
                self.queue_out.put(output)
                start_time = perf_counter()
        self.cleanup()
        self.queue_out.put(b'END')

    @property
    def last_time(self):
        return self._times[-1]

    def cleanup(self):
        pass


@dataclass
class SocketReceiverArguments:
    recv_host: str = field(
        default="localhost",
        metadata={
            "help": "The host IP address for the socket connection. Default is 'localhost'. Use "
                    "'0.0.0.0' to bind to all available interfaces on the host machine."
        }
    )
    recv_port: int = field(
        default=12345,
        metadata={
            "help": "The port number on which the socket server listens. Default is 12345."
        }
    )
    chunk_size: int = field(
        default=1024,
        metadata={
            "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
        }
    )
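# Audio enters the pipeline over a plain TCP socket: SocketReceiver accepts a single
# client connection and pushes fixed-size chunks of raw audio bytes onto a queue.
# receive_full_chunk loops over recv() because a single recv() call may return fewer
# bytes than requested.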
class SocketReceiver:
    def __init__(
        self,
        stop_event,
        queue_out,
        should_listen,
        host='0.0.0.0',
        port=12345,
        chunk_size=1024
    ):
        self.stop_event = stop_event
        self.queue_out = queue_out
        self.should_listen = should_listen
        self.chunk_size = chunk_size
        self.host = host
        self.port = port

    def receive_full_chunk(self, conn, chunk_size):
        data = b''
        while len(data) < chunk_size:
            packet = conn.recv(chunk_size - len(data))
            if not packet:
                # connection closed
                return None
            data += packet
        return data

    def run(self):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((self.host, self.port))
        self.socket.listen(1)
        self.conn, _ = self.socket.accept()
        logger.debug("receiver connected")

        self.should_listen.set()
        while not self.stop_event.is_set():
            audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
            if audio_chunk is None:
                # connection closed
                self.queue_out.put(b'END')
                break
            if self.should_listen.is_set():
                self.queue_out.put(audio_chunk)
        self.conn.close()
        logger.debug("Receiver closed")


@dataclass
class SocketSenderArguments:
    send_host: str = field(
        default="localhost",
        metadata={
            "help": "The host IP address for the socket connection. Default is 'localhost'. Use "
                    "'0.0.0.0' to bind to all available interfaces on the host machine."
        }
    )
    send_port: int = field(
        default=12346,
        metadata={
            "help": "The port number on which the socket server listens. Default is 12346."
        }
    )


class SocketSender:
    def __init__(
        self,
        stop_event,
        queue_in,
        host='0.0.0.0',
        port=12346
    ):
        self.stop_event = stop_event
        self.queue_in = queue_in
        self.host = host
        self.port = port

    def run(self):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((self.host, self.port))
        self.socket.listen(1)
        self.conn, _ = self.socket.accept()
        logger.debug("sender connected")

        while not self.stop_event.is_set():
            audio_chunk = self.queue_in.get()
            self.conn.sendall(audio_chunk)
            if isinstance(audio_chunk, bytes) and audio_chunk == b'END':
                break
        self.conn.close()
        logger.debug("Sender closed")


@dataclass
class VADHandlerArguments:
    thresh: float = field(
        default=0.3,
        metadata={
            "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
        }
    )
    sample_rate: int = field(
        default=16000,
        metadata={
            "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
        }
    )
    min_silence_ms: int = field(
        default=1000,
        metadata={
            "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 1000 ms."
        }
    )
    min_speech_ms: int = field(
        default=500,
        metadata={
            "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
        }
    )
    max_speech_ms: float = field(
        default=float('inf'),
        metadata={
            "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
        }
    )
    speech_pad_ms: int = field(
        default=30,
        metadata={
            "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 30 ms."
        }
    )
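# Voice activity detection stage: incoming int16 chunks are converted to float32 and fed
# to the Silero VAD iterator. When the iterator reports the end of a speech segment, the
# accumulated audio is yielded downstream (unless it is shorter than min_speech_ms or
# longer than max_speech_ms) and should_listen is cleared so further input is ignored
# until the spoken response has been streamed back.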
class VADHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        thresh=0.3,
        sample_rate=16000,
        min_silence_ms=1000,
        min_speech_ms=500,
        max_speech_ms=float('inf'),
        speech_pad_ms=30,
    ):
        self._should_listen = should_listen
        self._sample_rate = sample_rate
        self._min_silence_ms = min_silence_ms
        self._min_speech_ms = min_speech_ms
        self._max_speech_ms = max_speech_ms
        self.model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad')
        self.iterator = VADIterator(
            self.model,
            threshold=thresh,
            sampling_rate=sample_rate,
            min_silence_duration_ms=min_silence_ms,
            speech_pad_ms=speech_pad_ms,
        )

    def process(self, audio_chunk):
        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
        audio_float32 = int2float(audio_int16)
        vad_output = self.iterator(torch.from_numpy(audio_float32))
        if vad_output is not None:
            logger.debug("VAD: end of speech detected")
            array = torch.cat(vad_output).cpu().numpy()
            duration_ms = len(array) / self._sample_rate * 1000
            if duration_ms < self._min_speech_ms or duration_ms > self._max_speech_ms:
                logger.debug(f"audio input of duration: {len(array) / self._sample_rate}s, skipping")
            else:
                self._should_listen.clear()
                logger.debug("Stop listening")
                yield array


@dataclass
class WhisperSTTHandlerArguments:
    stt_model_name: str = field(
        default="distil-whisper/distil-large-v3",
        metadata={
            "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
        }
    )
    stt_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        }
    )
    stt_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        }
    )
    stt_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "The maximum number of new tokens to generate. Default is 128."
        }
    )
    stt_gen_num_beams: int = field(
        default=1,
        metadata={
            "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
        }
    )
    stt_gen_return_timestamps: bool = field(
        default=False,
        metadata={
            "help": "Whether to return timestamps with transcriptions. Default is False."
        }
    )
    stt_gen_task: str = field(
        default="transcribe",
        metadata={
            "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
        }
    )
    stt_gen_language: str = field(
        default="en",
        metadata={
            "help": "The language of the speech to transcribe. Default is 'en' for English."
        }
    )
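# Speech-to-text stage: the detected speech segment is transcribed with a Whisper-style
# model. pipeline_start is recorded here so that the TTS stage can report end-to-end
# latency up to the first generated audio chunk.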
class WhisperSTTHandler(BaseHandler):
    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={}
    ):
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)
        self.gen_kwargs = gen_kwargs

    def process(self, spoken_prompt):
        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        logger.debug("inferring whisper...")
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        pred_text = self.processor.batch_decode(
            pred_ids,
            skip_special_tokens=True,
            decode_with_timestamps=False
        )[0]
        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text


@dataclass
class LanguageModelHandlerArguments:
    llm_model_name: str = field(
        default="microsoft/Phi-3-mini-4k-instruct",
        metadata={
            "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
        }
    )
    llm_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        }
    )
    llm_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        }
    )
    user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        }
    )
    init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        }
    )
    init_chat_prompt: str = field(
        default="You are a helpful AI assistant.",
        metadata={
            "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
        }
    )
    llm_gen_max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of new tokens to generate in a single completion. Default is 128."}
    )
    llm_gen_temperature: float = field(
        default=0.0,
        metadata={"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."}
    )
    llm_gen_do_sample: bool = field(
        default=False,
        metadata={"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."}
    )
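# Language-model stage: the transcript is appended to a running chat history and the
# response is generated in a background thread via TextIteratorStreamer. Text is yielded
# sentence by sentence (using nltk's sent_tokenize) so the TTS stage can start speaking
# before the full answer has been generated.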
class LanguageModelHandler(BaseHandler):
    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        init_chat_role="system",
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.chat = [
            {"role": init_chat_role, "content": init_chat_prompt}
        ]
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs
        }
        self.user_role = user_role

    def process(self, prompt):
        self.chat.append(
            {"role": self.user_role, "content": prompt}
        )
        thread = Thread(target=self.pipe, args=(self.chat,), kwargs=self.gen_kwargs)
        thread.start()

        generated_text, printable_text = "", ""
        logger.debug("inferring language model...")
        for new_text in self.streamer:
            generated_text += new_text
            printable_text += new_text
            sentences = sent_tokenize(printable_text)
            if len(sentences) > 1:
                yield sentences[0]
                printable_text = new_text

        self.chat.append(
            {"role": "assistant", "content": generated_text}
        )

        # don't forget last sentence
        yield printable_text


@dataclass
class ParlerTTSHandlerArguments:
    tts_model_name: str = field(
        default="ylacombe/parler-tts-mini-jenny-30H",
        metadata={
            "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
        }
    )
    tts_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        }
    )
    tts_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        }
    )
    gen_kwargs: dict = field(
        default_factory=dict,
        metadata={
            "help": "Additional keyword arguments to pass to the model's generate method. Use this to customize generation settings."
        }
    )
    description: str = field(
        default=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, "
            "in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        metadata={
            "help": "Description of the speaker's voice and speaking style to guide the TTS model."
        }
    )
    play_steps_s: float = field(
        default=0.5,
        metadata={
            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
        }
    )
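# Text-to-speech stage: each sentence is synthesized with Parler-TTS. ParlerTTSStreamer
# yields audio every play_steps decoder frames, and the float waveform is scaled to int16
# before being sent back to the client. should_listen is set again once the sentence has
# been fully streamed.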
class ParlerTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        model_name="ylacombe/parler-tts-mini-jenny-30H",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        description=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, "
            "in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        play_steps_s=0.5
    ):
        torch_dtype = getattr(torch, torch_dtype)
        self._should_listen = should_listen
        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch_dtype
        ).to(device)
        self.device = device
        self.torch_dtype = torch_dtype

        tokenized_description = self.description_tokenizer(description, return_tensors="pt")
        input_ids = tokenized_description.input_ids.to(self.device)
        attention_mask = tokenized_description.attention_mask.to(self.device)

        self.gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **gen_kwargs
        }

        framerate = self.model.audio_encoder.config.frame_rate
        self.play_steps = int(framerate * play_steps_s)

    def process(self, llm_sentence):
        console.print(f"[green]ASSISTANT: {llm_sentence}")

        tokenized_prompt = self.prompt_tokenizer(llm_sentence, return_tensors="pt")
        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)

        streamer = ParlerTTSStreamer(self.model, device=self.device, play_steps=self.play_steps)
        tts_gen_kwargs = {
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
            "streamer": streamer,
            **self.gen_kwargs
        }
        torch.manual_seed(0)
        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
        thread.start()

        for i, audio_chunk in enumerate(streamer):
            if i == 0:
                logger.debug(f"time to first audio: {perf_counter() - pipeline_start:.3f}")
            audio_chunk = np.int16(audio_chunk * 32767)
            yield audio_chunk

        self._should_listen.set()


def prepare_args(args, prefix):
    gen_kwargs = {}
    for key in copy(args.__dict__):
        if key.startswith(prefix):
            value = args.__dict__.pop(key)
            new_key = key[len(prefix) + 1:]  # remove prefix and underscore
            if new_key.startswith("gen_"):
                gen_kwargs[new_key[4:]] = value  # remove 'gen_' and add to generation kwargs
            else:
                args.__dict__[new_key] = value
    args.__dict__["gen_kwargs"] = gen_kwargs
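# The pipeline is wired together in main(): each handler runs in its own thread and the
# stages communicate through queues (receiver -> VAD -> STT -> LLM -> TTS -> sender).
# A b'END' sentinel propagates through the queues so every thread can shut down cleanly
# when an upstream stage stops.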
def main():
    parser = HfArgumentParser((
        ModuleArguments,
        SocketReceiverArguments,
        SocketSenderArguments,
        VADHandlerArguments,
        WhisperSTTHandlerArguments,
        LanguageModelHandlerArguments,
        ParlerTTSHandlerArguments,
    ))

    # 0. Parse CLI arguments
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # Parse configurations from a JSON file if specified
        (
            module_kwargs,
            socket_receiver_kwargs,
            socket_sender_kwargs,
            vad_handler_kwargs,
            whisper_stt_handler_kwargs,
            language_model_handler_kwargs,
            parler_tts_handler_kwargs,
        ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        # Parse arguments from command line if no JSON file is provided
        (
            module_kwargs,
            socket_receiver_kwargs,
            socket_sender_kwargs,
            vad_handler_kwargs,
            whisper_stt_handler_kwargs,
            language_model_handler_kwargs,
            parler_tts_handler_kwargs,
        ) = parser.parse_args_into_dataclasses()

    global logger
    logging.basicConfig(
        level=module_kwargs.log_level.upper(),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    logger = logging.getLogger(__name__)

    prepare_args(whisper_stt_handler_kwargs, "stt")
    prepare_args(language_model_handler_kwargs, "llm")
    prepare_args(parler_tts_handler_kwargs, "tts")

    stop_event = Event()
    should_listen = Event()
    recv_audio_chunks_queue = Queue()
    send_audio_chunks_queue = Queue()
    spoken_prompt_queue = Queue()
    text_prompt_queue = Queue()
    llm_response_queue = Queue()

    vad = VADHandler(
        stop_event,
        queue_in=recv_audio_chunks_queue,
        queue_out=spoken_prompt_queue,
        setup_args=(should_listen,),
        setup_kwargs=vars(vad_handler_kwargs),
    )
    stt = WhisperSTTHandler(
        stop_event,
        queue_in=spoken_prompt_queue,
        queue_out=text_prompt_queue,
        setup_kwargs=vars(whisper_stt_handler_kwargs),
    )
    llm = LanguageModelHandler(
        stop_event,
        queue_in=text_prompt_queue,
        queue_out=llm_response_queue,
        setup_kwargs=vars(language_model_handler_kwargs),
    )
    tts = ParlerTTSHandler(
        stop_event,
        queue_in=llm_response_queue,
        queue_out=send_audio_chunks_queue,
        setup_args=(should_listen,),
        setup_kwargs=vars(parler_tts_handler_kwargs),
    )

    recv_handler = SocketReceiver(
        stop_event,
        recv_audio_chunks_queue,
        should_listen,
        host=socket_receiver_kwargs.recv_host,
        port=socket_receiver_kwargs.recv_port,
        chunk_size=socket_receiver_kwargs.chunk_size,
    )
    send_handler = SocketSender(
        stop_event,
        send_audio_chunks_queue,
        host=socket_sender_kwargs.send_host,
        port=socket_sender_kwargs.send_port,
    )

    try:
        pipeline_manager = ThreadManager([vad, tts, llm, stt, recv_handler, send_handler])
        pipeline_manager.start()
    except KeyboardInterrupt:
        pipeline_manager.stop()


if __name__ == "__main__":
    main()
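# Illustrative sketch of a matching client (not part of this script; the capture/playback
# side and the variable names below are assumptions): it would stream 16 kHz mono int16
# PCM to the receiver port in chunk_size-byte packets and play back whatever bytes arrive
# on the sender port until b'END' is received, e.g.
#
#   recv_sock = socket.create_connection(("localhost", 12345))  # microphone audio goes here
#   send_sock = socket.create_connection(("localhost", 12346))  # synthesized audio comes back
#   recv_sock.sendall(mic_chunk)      # mic_chunk: exactly chunk_size bytes of int16 PCM
#   audio_out = send_sock.recv(1024)  # raw int16 PCM from the TTS stage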