Skip to content
Snippets Groups Projects
Commit 060f5527 authored by Eustache Le Bihan's avatar Eustache Le Bihan
Browse files

improve clarity

parent d5e46072
No related branches found
No related tags found
No related merge requests found
...@@ -3,43 +3,43 @@ from dataclasses import dataclass, field ...@@ -3,43 +3,43 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class ParlerTTSHandlerArguments: class ParlerTTSHandlerArguments:
tts_model_name: str = field( parler_model_name: str = field(
default="ylacombe/parler-tts-mini-jenny-30H", default="ylacombe/parler-tts-mini-jenny-30H",
metadata={ metadata={
"help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'." "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
}, },
) )
tts_device: str = field( parler_device: str = field(
default="cuda", default="cuda",
metadata={ metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
}, },
) )
tts_torch_dtype: str = field( parler_torch_dtype: str = field(
default="float16", default="float16",
metadata={ metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
}, },
) )
tts_compile_mode: str = field( parler_compile_mode: str = field(
default=None, default=None,
metadata={ metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
}, },
) )
tts_gen_min_new_tokens: int = field( parler_gen_min_new_tokens: int = field(
default=64, default=64,
metadata={ metadata={
"help": "Minimum number of new tokens to generate in a single completion. Default is 64." "help": "Minimum number of new tokens to generate in a single completion. Default is 64."
}, },
) )
tts_gen_max_new_tokens: int = field( parler_gen_max_new_tokens: int = field(
default=512, default=512,
metadata={ metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 512." "help": "Maximum number of new tokens to generate in a single completion. Default is 512."
}, },
) )
description: str = field( parler_description: str = field(
default=( default=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast." "She speaks very fast."
...@@ -48,13 +48,13 @@ class ParlerTTSHandlerArguments: ...@@ -48,13 +48,13 @@ class ParlerTTSHandlerArguments:
"help": "Description of the speaker's voice and speaking style to guide the TTS model." "help": "Description of the speaker's voice and speaking style to guide the TTS model."
}, },
) )
play_steps_s: float = field( parler_play_steps_s: float = field(
default=1.0, default=1.0,
metadata={ metadata={
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 seconds." "help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 seconds."
}, },
) )
max_prompt_pad_length: int = field( parler_max_prompt_pad_length: int = field(
default=8, default=8,
metadata={ metadata={
"help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 possible." "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 possible."
......
...@@ -2,68 +2,68 @@ from dataclasses import dataclass, field ...@@ -2,68 +2,68 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class LanguageModelHandlerArguments: class TransformersLanguageModelHandlerArguments:
lm_model_name: str = field( transformers_lm_model_name: str = field(
default="HuggingFaceTB/SmolLM-360M-Instruct", default="HuggingFaceTB/SmolLM-360M-Instruct",
metadata={ metadata={
"help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'." "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'."
}, },
) )
lm_device: str = field( transformers_lm_device: str = field(
default="cuda", default="cuda",
metadata={ metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
}, },
) )
lm_torch_dtype: str = field( transformers_lm_torch_dtype: str = field(
default="float16", default="float16",
metadata={ metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
}, },
) )
user_role: str = field( transformers_lm_user_role: str = field(
default="user", default="user",
metadata={ metadata={
"help": "Role assigned to the user in the chat context. Default is 'user'." "help": "Role assigned to the user in the chat context. Default is 'user'."
}, },
) )
init_chat_role: str = field( transformers_lm_init_chat_role: str = field(
default="system", default="system",
metadata={ metadata={
"help": "Initial role for setting up the chat context. Default is 'system'." "help": "Initial role for setting up the chat context. Default is 'system'."
}, },
) )
init_chat_prompt: str = field( transformers_lm_init_chat_prompt: str = field(
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
metadata={ metadata={
"help": "The initial chat prompt to establish context for the language model." "help": "The initial chat prompt to establish context for the language model."
}, },
) )
lm_gen_max_new_tokens: int = field( transformers_lm_gen_max_new_tokens: int = field(
default=128, default=128,
metadata={ metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 128." "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
}, },
) )
lm_gen_min_new_tokens: int = field( transformers_lm_gen_min_new_tokens: int = field(
default=0, default=0,
metadata={ metadata={
"help": "Minimum number of new tokens to generate in a single completion. Default is 0." "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
}, },
) )
lm_gen_temperature: float = field( transformers_lm_gen_temperature: float = field(
default=0.0, default=0.0,
metadata={ metadata={
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
}, },
) )
lm_gen_do_sample: bool = field( transformers_lm_gen_do_sample: bool = field(
default=False, default=False,
metadata={ metadata={
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
}, },
) )
chat_size: int = field( transformers_lm_chat_size: int = field(
default=2, default=2,
metadata={ metadata={
"help": "Number of assistant-user interactions to keep in the chat history. None for no limitation." "help": "Number of assistant-user interactions to keep in the chat history. None for no limitation."
......
...@@ -4,49 +4,49 @@ from typing import Optional ...@@ -4,49 +4,49 @@ from typing import Optional
@dataclass @dataclass
class WhisperSTTHandlerArguments: class WhisperSTTHandlerArguments:
stt_model_name: str = field( whisper_model_name: str = field(
default="distil-whisper/distil-large-v3", default="distil-whisper/distil-large-v3",
metadata={ metadata={
"help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'." "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
}, },
) )
stt_device: str = field( whisper_device: str = field(
default="cuda", default="cuda",
metadata={ metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
}, },
) )
stt_torch_dtype: str = field( whisper_torch_dtype: str = field(
default="float16", default="float16",
metadata={ metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
}, },
) )
stt_compile_mode: str = field( whisper_compile_mode: str = field(
default=None, default=None,
metadata={ metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
}, },
) )
stt_gen_max_new_tokens: int = field( whisper_gen_max_new_tokens: int = field(
default=128, default=128,
metadata={ metadata={
"help": "The maximum number of new tokens to generate. Default is 128." "help": "The maximum number of new tokens to generate. Default is 128."
}, },
) )
stt_gen_num_beams: int = field( whisper_gen_num_beams: int = field(
default=1, default=1,
metadata={ metadata={
"help": "The number of beams for beam search. Default is 1, implying greedy decoding." "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
}, },
) )
stt_gen_return_timestamps: bool = field( whisper_gen_return_timestamps: bool = field(
default=False, default=False,
metadata={ metadata={
"help": "Whether to return timestamps with transcriptions. Default is False." "help": "Whether to return timestamps with transcriptions. Default is False."
}, },
) )
stt_gen_task: str = field( whisper_gen_task: str = field(
default="transcribe", default="transcribe",
metadata={ metadata={
"help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
......
...@@ -9,7 +9,7 @@ from typing import Optional ...@@ -9,7 +9,7 @@ from typing import Optional
from sys import platform from sys import platform
from VAD.vad_handler import VADHandler from VAD.vad_handler import VADHandler
from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments from arguments_classes.transformers_language_model_arguments import TransformersLanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import ( from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
) )
...@@ -76,7 +76,7 @@ def parse_arguments(): ...@@ -76,7 +76,7 @@ def parse_arguments():
VADHandlerArguments, VADHandlerArguments,
WhisperSTTHandlerArguments, WhisperSTTHandlerArguments,
ParaformerSTTHandlerArguments, ParaformerSTTHandlerArguments,
LanguageModelHandlerArguments, TransformersLanguageModelHandlerArguments,
OpenApiLanguageModelHandlerArguments, OpenApiLanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
...@@ -161,7 +161,7 @@ def prepare_all_args( ...@@ -161,7 +161,7 @@ def prepare_all_args(
module_kwargs, module_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs, open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
...@@ -172,7 +172,7 @@ def prepare_all_args( ...@@ -172,7 +172,7 @@ def prepare_all_args(
module_kwargs, module_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs, open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
...@@ -181,12 +181,12 @@ def prepare_all_args( ...@@ -181,12 +181,12 @@ def prepare_all_args(
) )
rename_args(whisper_stt_handler_kwargs, "stt") rename_args(whisper_stt_handler_kwargs, "whisper")
rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
rename_args(language_model_handler_kwargs, "lm") rename_args(transformers_language_model_handler_kwargs, "transformers_lm")
rename_args(mlx_language_model_handler_kwargs, "mlx_lm") rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
rename_args(open_api_language_model_handler_kwargs, "open_api") rename_args(open_api_language_model_handler_kwargs, "open_api")
rename_args(parler_tts_handler_kwargs, "tts") rename_args(parler_tts_handler_kwargs, "parler")
rename_args(melo_tts_handler_kwargs, "melo") rename_args(melo_tts_handler_kwargs, "melo")
rename_args(chat_tts_handler_kwargs, "chat_tts") rename_args(chat_tts_handler_kwargs, "chat_tts")
...@@ -210,7 +210,7 @@ def build_pipeline( ...@@ -210,7 +210,7 @@ def build_pipeline(
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs, open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
...@@ -262,14 +262,45 @@ def build_pipeline( ...@@ -262,14 +262,45 @@ def build_pipeline(
setup_kwargs=vars(vad_handler_kwargs), setup_kwargs=vars(vad_handler_kwargs),
) )
stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) stt = get_stt_handler(
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs) module_kwargs,
tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) stop_event,
spoken_prompt_queue,
text_prompt_queue,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs
)
lm = get_llm_handler(
module_kwargs,
stop_event,
text_prompt_queue,
lm_response_queue,
transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs
)
tts = get_tts_handler(
module_kwargs,
stop_event,
lm_response_queue,
send_audio_chunks_queue,
should_listen,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs
)
return ThreadManager([*comms_handlers, vad, stt, lm, tts]) return ThreadManager([*comms_handlers, vad, stt, lm, tts])
def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs): def get_stt_handler(
module_kwargs,
stop_event,
spoken_prompt_queue,
text_prompt_queue,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs
):
if module_kwargs.stt == "whisper": if module_kwargs.stt == "whisper":
from STT.whisper_stt_handler import WhisperSTTHandler from STT.whisper_stt_handler import WhisperSTTHandler
return WhisperSTTHandler( return WhisperSTTHandler(
...@@ -337,7 +368,16 @@ def get_llm_handler( ...@@ -337,7 +368,16 @@ def get_llm_handler(
raise ValueError("The LLM should be either transformers or mlx-lm") raise ValueError("The LLM should be either transformers or mlx-lm")
def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs): def get_tts_handler(
module_kwargs,
stop_event,
lm_response_queue,
send_audio_chunks_queue,
should_listen,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs
):
if module_kwargs.tts == "parler": if module_kwargs.tts == "parler":
from TTS.parler_handler import ParlerTTSHandler from TTS.parler_handler import ParlerTTSHandler
return ParlerTTSHandler( return ParlerTTSHandler(
...@@ -387,7 +427,7 @@ def main(): ...@@ -387,7 +427,7 @@ def main():
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs, open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
...@@ -401,14 +441,14 @@ def main(): ...@@ -401,14 +441,14 @@ def main():
module_kwargs, module_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs, open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
chat_tts_handler_kwargs, chat_tts_handler_kwargs,
) )
queues_and_events = initialize_queues_and_events() queues_and_events = initialize_queues_and_events()
pipeline_manager = build_pipeline( pipeline_manager = build_pipeline(
...@@ -418,7 +458,7 @@ def main(): ...@@ -418,7 +458,7 @@ def main():
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs, paraformer_stt_handler_kwargs,
language_model_handler_kwargs, transformers_language_model_handler_kwargs,
open_api_language_model_handler_kwargs, open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment