import logging
import os
import socket
import sys
import threading
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue
from threading import Event, Thread
from time import perf_counter
from typing import Optional
from LLM.mlx_lm import MLXLanguageModelHandler
from TTS.melotts import MeloTTSHandler
from baseHandler import BaseHandler
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
import numpy as np
import torch
import nltk
from nltk.tokenize import sent_tokenize
from rich.console import Console
from transformers import (
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoTokenizer,
HfArgumentParser,
pipeline,
TextIteratorStreamer,
)
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
import librosa
from local_audio_streamer import LocalAudioStreamer
from utils import VADIterator, int2float, next_power_of_2
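# Pipeline overview: audio comes in either from the local microphone ('local' mode) or over a
# socket ('socket' mode), then flows through VAD -> STT -> LLM -> TTS, with each handler running
# on its own thread and communicating through queues (see main() below).
# Example invocation (flags defined by the dataclasses below):
#   python s2s_pipeline.py --mode local --tts melo --log_level debug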
# Ensure that the necessary NLTK resources are available
try:
nltk.data.find('tokenizers/punkt_tab')
except (LookupError, OSError):
nltk.download('punkt_tab')
# caching allows ~50% compilation time reduction
# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
CURRENT_DIR = Path(__file__).resolve().parent
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
console = Console()
@dataclass
class ModuleArguments:
device: Optional[str] = field(
default=None,
metadata={"help": "If specified, overrides the device for all handlers."},
)
mode: Optional[str] = field(
default="local",
metadata={
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
},
)
tts: Optional[str] = field(
default="parler",
metadata={
"help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
},
)
log_level: str = field(
default="info",
metadata={
"help": "Provide logging level. Example --log_level debug, default=warning."
},
)
class ThreadManager:
"""
Manages multiple threads used to execute given handler tasks.
"""
def __init__(self, handlers):
self.handlers = handlers
self.threads = []
def start(self):
for handler in self.handlers:
thread = threading.Thread(target=handler.run)
self.threads.append(thread)
thread.start()
def stop(self):
for handler in self.handlers:
handler.stop_event.set()
for thread in self.threads:
thread.join()
@dataclass
class SocketReceiverArguments:
recv_host: str = field(
default="localhost",
metadata={
"help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
"available interfaces on the host machine."
},
)
recv_port: int = field(
default=12345,
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
},
)
chunk_size: int = field(
default=1024,
metadata={
"help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
},
)
class SocketReceiver:
"""
Handles reception of the audio packets from the client.
"""
def __init__(
self,
stop_event,
queue_out,
should_listen,
host="0.0.0.0",
port=12345,
chunk_size=1024,
):
self.stop_event = stop_event
self.queue_out = queue_out
self.should_listen = should_listen
self.chunk_size = chunk_size
self.host = host
self.port = port
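    # TCP recv() may return fewer bytes than requested, so loop until a full
    # chunk_size-sized buffer has been assembled (or the peer closes the connection).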
def receive_full_chunk(self, conn, chunk_size):
data = b""
while len(data) < chunk_size:
packet = conn.recv(chunk_size - len(data))
if not packet:
# connection closed
return None
data += packet
return data
def run(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((self.host, self.port))
self.socket.listen(1)
logger.info("Receiver waiting to be connected...")
self.conn, _ = self.socket.accept()
logger.info("receiver connected")
self.should_listen.set()
while not self.stop_event.is_set():
audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
if audio_chunk is None:
# connection closed
self.queue_out.put(b"END")
break
if self.should_listen.is_set():
self.queue_out.put(audio_chunk)
self.conn.close()
logger.info("Receiver closed")
@dataclass
class SocketSenderArguments:
send_host: str = field(
default="localhost",
metadata={
"help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
"available interfaces on the host machine."
},
)
send_port: int = field(
default=12346,
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
},
)
class SocketSender:
"""
Handles sending generated audio packets to the clients.
"""
def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
self.stop_event = stop_event
self.queue_in = queue_in
self.host = host
self.port = port
def run(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((self.host, self.port))
self.socket.listen(1)
logger.info("Sender waiting to be connected...")
self.conn, _ = self.socket.accept()
logger.info("sender connected")
while not self.stop_event.is_set():
audio_chunk = self.queue_in.get()
self.conn.sendall(audio_chunk)
if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
break
self.conn.close()
logger.info("Sender closed")
@dataclass
class VADHandlerArguments:
thresh: float = field(
default=0.3,
metadata={
"help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
},
)
sample_rate: int = field(
default=16000,
metadata={
"help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
},
)
min_silence_ms: int = field(
default=250,
metadata={
"help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
},
)
min_speech_ms: int = field(
default=500,
metadata={
"help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
},
)
max_speech_ms: float = field(
default=float("inf"),
metadata={
"help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
},
)
speech_pad_ms: int = field(
default=250,
metadata={
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
},
)
class VADHandler(BaseHandler):
"""
    Handles voice activity detection. When voice activity is detected, audio is accumulated until the end of
    speech is detected and then passed on to the next handler.
"""
def setup(
self,
should_listen,
thresh=0.3,
sample_rate=16000,
min_silence_ms=1000,
min_speech_ms=500,
max_speech_ms=float("inf"),
speech_pad_ms=30,
):
self.should_listen = should_listen
self.sample_rate = sample_rate
self.min_silence_ms = min_silence_ms
self.min_speech_ms = min_speech_ms
self.max_speech_ms = max_speech_ms
self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
self.iterator = VADIterator(
self.model,
threshold=thresh,
sampling_rate=sample_rate,
min_silence_duration_ms=min_silence_ms,
speech_pad_ms=speech_pad_ms,
)
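    # The VADIterator buffers incoming chunks while speech is ongoing and returns the accumulated
    # audio once the end of speech is detected; until then it returns None (or an empty result).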
def process(self, audio_chunk):
audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
audio_float32 = int2float(audio_int16)
vad_output = self.iterator(torch.from_numpy(audio_float32))
if vad_output is not None and len(vad_output) != 0:
logger.debug("VAD: end of speech detected")
array = torch.cat(vad_output).cpu().numpy()
duration_ms = len(array) / self.sample_rate * 1000
if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
logger.debug(
f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
)
else:
self.should_listen.clear()
logger.debug("Stop listening")
yield array
@dataclass
class WhisperSTTHandlerArguments:
stt_model_name: str = field(
default="distil-whisper/distil-large-v3",
metadata={
"help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
},
)
stt_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
stt_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
stt_compile_mode: str = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
},
)
stt_gen_max_new_tokens: int = field(
default=128,
metadata={
"help": "The maximum number of new tokens to generate. Default is 128."
},
)
stt_gen_num_beams: int = field(
default=1,
metadata={
"help": "The number of beams for beam search. Default is 1, implying greedy decoding."
},
)
stt_gen_return_timestamps: bool = field(
default=False,
metadata={
"help": "Whether to return timestamps with transcriptions. Default is False."
},
)
# stt_gen_task: str = field(
# default="transcribe",
# metadata={
# "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
# },
# )
# stt_gen_language: str = field(
# default="en",
# metadata={
# "help": "The language of the speech to transcribe. Default is 'en' for English."
# },
# )
class WhisperSTTHandler(BaseHandler):
"""
Handles the Speech To Text generation using a Whisper model.
"""
def setup(
self,
model_name="distil-whisper/distil-large-v3",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={},
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
).to(device)
        # optionally compile the forward pass; a static cache implementation is needed for fullgraph compilation of generate
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode=self.compile_mode, fullgraph=True
)
self.warmup()
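    # Whisper consumes fixed-size log-mel spectrograms (30 s of audio, i.e. 3000 mel frames);
    # the processor takes care of padding/truncating the raw waveform accordingly.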
def prepare_model_inputs(self, spoken_prompt):
input_features = self.processor(
spoken_prompt, sampling_rate=16000, return_tensors="pt"
).input_features
input_features = input_features.to(self.device, dtype=self.torch_dtype)
return input_features
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
        # one warmup step suffices for the 'default' compile mode; use two steps otherwise
        # (no compilation, or compile modes that capture CUDA graphs)
n_steps = 1 if self.compile_mode == "default" else 2
dummy_input = torch.randn(
(1, self.model.config.num_mel_bins, 3000),
dtype=self.torch_dtype,
device=self.device,
)
if self.compile_mode not in (None, "default"):
# generating more tokens than previously will trigger CUDA graphs capture
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
else:
warmup_gen_kwargs = self.gen_kwargs
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
_ = self.model.generate(dummy_input, **warmup_gen_kwargs)
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
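    # process() also records `pipeline_start`, the reference timestamp used by the TTS handler
    # to report time-to-first-audio.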
def process(self, spoken_prompt):
logger.debug("infering whisper...")
global pipeline_start
pipeline_start = perf_counter()
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
yield pred_text
@dataclass
class LanguageModelHandlerArguments:
lm_model_name: str = field(
default="HuggingFaceTB/SmolLM-360M-Instruct",
metadata={
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
},
)
lm_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
lm_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
user_role: str = field(
default="user",
metadata={
"help": "Role assigned to the user in the chat context. Default is 'user'."
},
)
init_chat_role: str = field(
default='system',
metadata={
"help": "Initial role for setting up the chat context. Default is 'system'."
},
)
init_chat_prompt: str = field(
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
metadata={
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
},
)
lm_gen_max_new_tokens: int = field(
default=128,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 128."
},
)
lm_gen_temperature: float = field(
default=0.0,
metadata={
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
},
)
lm_gen_do_sample: bool = field(
default=False,
metadata={
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
},
)
chat_size: int = field(
default=2,
metadata={
"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
},
)
class Chat:
"""
    Maintains a size-bounded chat history to avoid out-of-memory issues in long conversations.
"""
def __init__(self, size):
self.size = size
self.init_chat_message = None
        # the buffer is trimmed in pairs, since each step appends a user prompt and an assistant answer
self.buffer = []
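    # append() trims the history once it reaches 2 * (size + 1) messages, dropping the oldest
    # user/assistant pair, so at most `size` complete past exchanges are kept.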
def append(self, item):
self.buffer.append(item)
if len(self.buffer) == 2 * (self.size + 1):
self.buffer.pop(0)
self.buffer.pop(0)
def init_chat(self, init_chat_message):
self.init_chat_message = init_chat_message
def to_list(self):
if self.init_chat_message:
return [self.init_chat_message] + self.buffer
else:
return self.buffer
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
"""
def setup(
self,
model_name="microsoft/Phi-3-mini-4k-instruct",
device="cuda",
torch_dtype="float16",
gen_kwargs={},
user_role="user",
chat_size=1,
init_chat_role=None,
init_chat_prompt="You are a helpful AI assistant.",
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
self.pipe = pipeline(
"text-generation", model=self.model, tokenizer=self.tokenizer, device=device
)
self.streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
self.gen_kwargs = {
"streamer": self.streamer,
"return_full_text": False,
**gen_kwargs,
}
self.chat = Chat(chat_size)
if init_chat_role:
if not init_chat_prompt:
raise ValueError(
"An initial promt needs to be specified when setting init_chat_role."
)
self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
self.user_role = user_role
self.warmup()
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
n_steps = 2
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
thread = Thread(
target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
)
thread.start()
for _ in self.streamer:
pass
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
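    # Generation runs on a background thread that feeds the TextIteratorStreamer; the streamed
    # text is cut into sentences with NLTK so that the TTS can start speaking before the full
    # answer has been generated.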
def process(self, prompt):
logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
)
thread.start()
if self.device == "mps":
generated_text = ""
for new_text in self.streamer:
generated_text += new_text
printable_text = generated_text
torch.mps.empty_cache()
else:
generated_text, printable_text = "", ""
for new_text in self.streamer:
generated_text += new_text
printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield sentences[0]
                    # keep any unfinished trailing text for the next iteration
                    printable_text = " ".join(sentences[1:])
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text
@dataclass
class ParlerTTSHandlerArguments:
tts_model_name: str = field(
default="ylacombe/parler-tts-mini-jenny-30H",
metadata={
"help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
},
)
tts_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
tts_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
    tts_compile_mode: Optional[str] = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
},
)
tts_gen_min_new_tokens: int = field(
default=64,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
},
)
tts_gen_max_new_tokens: int = field(
default=512,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
},
)
description: str = field(
default=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast."
),
metadata={
"help": "Description of the speaker's voice and speaking style to guide the TTS model."
},
)
play_steps_s: float = field(
default=1.0,
metadata={
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
},
)
max_prompt_pad_length: int = field(
default=8,
metadata={
"help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
},
)
class ParlerTTSHandler(BaseHandler):
def setup(
self,
should_listen,
model_name="ylacombe/parler-tts-mini-jenny-30H",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={},
max_prompt_pad_length=8,
description=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast."
),
play_steps_s=1,
blocksize=512,
):
self.should_listen = should_listen
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.gen_kwargs = gen_kwargs
self.compile_mode = compile_mode
self.max_prompt_pad_length = max_prompt_pad_length
self.description = description
self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = ParlerTTSForConditionalGeneration.from_pretrained(
model_name, torch_dtype=self.torch_dtype
).to(device)
framerate = self.model.audio_encoder.config.frame_rate
self.play_steps = int(framerate * play_steps_s)
self.blocksize = blocksize
if self.compile_mode not in (None, "default"):
logger.warning(
"Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
)
self.compile_mode = "default"
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode=self.compile_mode, fullgraph=True
)
self.warmup()
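    # Parler-TTS conditions on two text inputs: a natural-language description of the voice and
    # the prompt to be spoken. Each is tokenized separately with its own attention mask.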
def prepare_model_inputs(
self,
prompt,
max_length_prompt=50,
pad=False,
):
pad_args_prompt = (
{"padding": "max_length", "max_length": max_length_prompt} if pad else {}
)
tokenized_description = self.description_tokenizer(
self.description, return_tensors="pt"
)
input_ids = tokenized_description.input_ids.to(self.device)
attention_mask = tokenized_description.attention_mask.to(self.device)
tokenized_prompt = self.prompt_tokenizer(
prompt, return_tensors="pt", **pad_args_prompt
)
prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
gen_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"prompt_input_ids": prompt_input_ids,
"prompt_attention_mask": prompt_attention_mask,
**self.gen_kwargs,
}
return gen_kwargs
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
        # one warmup step suffices for the 'default' compile mode; use two steps otherwise
        # (no compilation, or compile modes that capture CUDA graphs)
n_steps = 1 if self.compile_mode == "default" else 2
if self.device == "cuda":
torch.cuda.synchronize()
start_event.record()
if self.compile_mode:
pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
for pad_length in pad_lengths[::-1]:
model_kwargs = self.prepare_model_inputs(
"dummy prompt", max_length_prompt=pad_length, pad=True
)
for _ in range(n_steps):
_ = self.model.generate(**model_kwargs)
logger.info(f"Warmed up length {pad_length} tokens!")
else:
model_kwargs = self.prepare_model_inputs("dummy prompt")
for _ in range(n_steps):
_ = self.model.generate(**model_kwargs)
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
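    # With compilation enabled, prompts are padded to the next power of two so that only a small,
    # fixed set of input shapes ever reaches the compiled model (each shape was warmed up above).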
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
pad_args = {}
if self.compile_mode:
            # pad to the closest upper power of two
pad_length = next_power_of_2(nb_tokens)
logger.debug(f"padding to {pad_length}")
pad_args["pad"] = True
pad_args["max_length_prompt"] = pad_length
tts_gen_kwargs = self.prepare_model_inputs(
llm_sentence,
**pad_args,
)
streamer = ParlerTTSStreamer(
self.model, device=self.device, play_steps=self.play_steps
)
tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
torch.manual_seed(0)
thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
thread.start()
for i, audio_chunk in enumerate(streamer):
if i == 0 and 'pipeline_start' in globals():
logger.info(
f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
)
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
            for j in range(0, len(audio_chunk), self.blocksize):
                yield np.pad(
                    audio_chunk[j : j + self.blocksize],
                    (0, self.blocksize - len(audio_chunk[j : j + self.blocksize])),
                )
self.should_listen.set()
def prepare_args(args, prefix):
"""
    Renames arguments by removing the given prefix and collects generation-specific arguments into gen_kwargs.
"""
gen_kwargs = {}
for key in copy(args.__dict__):
if key.startswith(prefix):
value = args.__dict__.pop(key)
new_key = key[len(prefix) + 1 :] # Remove prefix and underscore
if new_key.startswith("gen_"):
gen_kwargs[new_key[4:]] = value # Remove 'gen_' and add to dict
else:
args.__dict__[new_key] = value
args.__dict__["gen_kwargs"] = gen_kwargs
def main():
parser = HfArgumentParser(
(
ModuleArguments,
SocketReceiverArguments,
SocketSenderArguments,
VADHandlerArguments,
WhisperSTTHandlerArguments,
LanguageModelHandlerArguments,
ParlerTTSHandlerArguments,
)
)
# 0. Parse CLI arguments
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# Parse configurations from a JSON file if specified
(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
language_model_handler_kwargs,
parler_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
# Parse arguments from command line if no JSON file is provided
(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
language_model_handler_kwargs,
parler_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses()
# 1. Handle logger
global logger
logging.basicConfig(
level=module_kwargs.log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# torch compile logs
if module_kwargs.log_level == "debug":
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
# 2. Prepare each part's arguments
def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
if common_device:
for kwargs in handler_kwargs:
if hasattr(kwargs, "lm_device"):
kwargs.lm_device = common_device
if hasattr(kwargs, "tts_device"):
kwargs.tts_device = common_device
if hasattr(kwargs, "stt_device"):
kwargs.stt_device = common_device
# Call this function with the common device and all the handlers
overwrite_device_argument(
module_kwargs.device,
language_model_handler_kwargs,
parler_tts_handler_kwargs,
whisper_stt_handler_kwargs,
)
prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(language_model_handler_kwargs, "lm")
prepare_args(parler_tts_handler_kwargs, "tts")
# 3. Build the pipeline
stop_event = Event()
    # used to stop putting received audio chunks in the queue until all sentences have been processed by the TTS
should_listen = Event()
recv_audio_chunks_queue = Queue()
send_audio_chunks_queue = Queue()
spoken_prompt_queue = Queue()
text_prompt_queue = Queue()
lm_response_queue = Queue()
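    # Queue wiring: recv_audio_chunks -> VAD -> spoken_prompt -> STT -> text_prompt -> LM
    # -> lm_response -> TTS -> send_audio_chunks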
if module_kwargs.mode == "local":
local_audio_streamer = LocalAudioStreamer(
input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
)
comms_handlers = [local_audio_streamer]
should_listen.set()
else:
comms_handlers = [
SocketReceiver(
stop_event,
recv_audio_chunks_queue,
should_listen,
host=socket_receiver_kwargs.recv_host,
port=socket_receiver_kwargs.recv_port,
chunk_size=socket_receiver_kwargs.chunk_size,
),
SocketSender(
stop_event,
send_audio_chunks_queue,
host=socket_sender_kwargs.send_host,
port=socket_sender_kwargs.send_port,
),
]
vad = VADHandler(
stop_event,
queue_in=recv_audio_chunks_queue,
queue_out=spoken_prompt_queue,
setup_args=(should_listen,),
setup_kwargs=vars(vad_handler_kwargs),
)
stt = LightningWhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
setup_kwargs=vars(whisper_stt_handler_kwargs),
)
lm = MLXLanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs),
)
if module_kwargs.tts == 'parler':
torch._inductor.config.fx_graph_cache = True
        # mind this parameter! It should be >= 2 * the number of padded prompt sizes used for TTS
torch._dynamo.config.cache_size_limit = 15
tts = ParlerTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
setup_args=(should_listen,),
setup_kwargs=vars(parler_tts_handler_kwargs),
)
elif module_kwargs.tts == 'melo':
tts = MeloTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
setup_args=(should_listen,),
)
else:
raise ValueError("The TTS should be either parler or melo")
# 4. Run the pipeline
    pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
    try:
        pipeline_manager.start()
    except KeyboardInterrupt:
        pipeline_manager.stop()
if __name__ == "__main__":
main()