diff --git a/LLM/language_model.py b/LLM/language_model.py
index ddeb34b1e6895a6ffd77a9f734cb17ad50a1c3a0..3fb332974d64509de8070e21d7f290fff66197a9 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -25,6 +25,7 @@ WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
     "zh": "chinese",
     "ja": "japanese",
     "ko": "korean",
+    "hi": "hindi",
 }
 
 class LanguageModelHandler(BaseHandler):
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 09300879e1dea3f790e3db40349b1da2b9675888..682862146f8c47beeb24f480ddcc7ae4967a2fd7 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -19,6 +19,7 @@ SUPPORTED_LANGUAGES = [
     "zh",
     "ja",
     "ko",
+    "hi"
 ]
 
diff --git a/TTS/facebookmms_handler.py b/TTS/facebookmms_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a370582050d5b03ca8afee653ab3b3bf18857b71
--- /dev/null
+++ b/TTS/facebookmms_handler.py
@@ -0,0 +1,128 @@
+from transformers import VitsModel, AutoTokenizer
+import torch
+import numpy as np
+import librosa
+from rich.console import Console
+from baseHandler import BaseHandler
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    level=logging.DEBUG
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+WHISPER_LANGUAGE_TO_FACEBOOK_LANGUAGE = {
+    "<|en|>": "eng",
+    "<|fr|>": "fra",
+    "<|es|>": "spa",
+    "<|ko|>": "kor",
+    "<|hi|>": "hin",
+}
+
+class FacebookMMSTTSHandler(BaseHandler):
+    def setup(
+        self,
+        should_listen,
+        facebook_mms_device="cuda",
+        facebook_mms_torch_dtype="float32",
+        language="en",
+        stream=True,
+        chunk_size=512,
+        **kwargs
+    ):
+        self.should_listen = should_listen
+        self.device = facebook_mms_device
+        self.torch_dtype = getattr(torch, facebook_mms_torch_dtype)
+        self.stream = stream
+        self.chunk_size = chunk_size
+        self.language = "<|" + language + "|>"
+
+        self.load_model(self.language)
+
+    def load_model(self, language_id):
+        model_name = f"facebook/mms-tts-{WHISPER_LANGUAGE_TO_FACEBOOK_LANGUAGE[language_id]}"
+        logger.info(f"Loading model: {model_name}")
+        self.model = VitsModel.from_pretrained(model_name).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.language = language_id
+
+    def generate_audio(self, text):
+        if not text:
+            logger.warning("Received empty text input")
+            return None
+
+        try:
+            logger.debug(f"Tokenizing text: {text}")
+            logger.debug(f"Current language: {self.language}")
+            logger.debug(f"Tokenizer: {self.tokenizer}")
+
+            inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+            input_ids = inputs.input_ids.to(self.device).long()
+            attention_mask = inputs.attention_mask.to(self.device)
+
+            logger.debug(f"Input IDs shape: {input_ids.shape}, dtype: {input_ids.dtype}")
+            logger.debug(f"Input IDs: {input_ids}")
+
+            if input_ids.numel() == 0:
+                logger.error("Input IDs tensor is empty")
+                return None
+
+            with torch.no_grad():
+                output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+
+            logger.debug(f"Output waveform shape: {output.waveform.shape}")
+            return output.waveform
+        except Exception as e:
+            logger.error(f"Error in generate_audio: {str(e)}")
+            logger.exception("Full traceback:")
+            return None
+
+    def process(self, llm_sentence):
+        language_id = None
+
+        if isinstance(llm_sentence, tuple):
+            llm_sentence, language_id = llm_sentence
+
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        logger.debug(f"Processing text: {llm_sentence}")
+        logger.debug(f"Language ID: {language_id}")
+
+        if language_id is not None and self.language != language_id:
+            try:
+                logger.info(f"Switching language from {self.language} to {language_id}")
+                self.load_model(language_id)
+            except KeyError:
+                console.print(f"[red]Language {language_id} not supported by Facebook MMS. Using {self.language} instead.")
+                logger.warning(f"Unsupported language: {language_id}")
+
+        audio_output = self.generate_audio(llm_sentence)
+
+        if audio_output is None or audio_output.numel() == 0:
+            logger.warning("No audio output generated")
+            self.should_listen.set()
+            return
+
+        audio_numpy = audio_output.cpu().numpy().squeeze()
+        logger.debug(f"Raw audio shape: {audio_numpy.shape}, dtype: {audio_numpy.dtype}")
+
+        audio_resampled = librosa.resample(audio_numpy, orig_sr=self.model.config.sampling_rate, target_sr=16000)
+        logger.debug(f"Resampled audio shape: {audio_resampled.shape}, dtype: {audio_resampled.dtype}")
+
+        audio_int16 = (audio_resampled * 32767).astype(np.int16)
+        logger.debug(f"Final audio shape: {audio_int16.shape}, dtype: {audio_int16.dtype}")
+
+        if self.stream:
+            for i in range(0, len(audio_int16), self.chunk_size):
+                chunk = audio_int16[i:i + self.chunk_size]
+                yield np.pad(chunk, (0, self.chunk_size - len(chunk)))
+        else:
+            for i in range(0, len(audio_int16), self.chunk_size):
+                yield np.pad(
+                    audio_int16[i : i + self.chunk_size],
+                    (0, self.chunk_size - len(audio_int16[i : i + self.chunk_size])),
+                )
+
+        self.should_listen.set()
\ No newline at end of file
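For reference, the generation path the new handler wraps can be exercised outside the pipeline. This is a minimal standalone sketch, not part of the patch; the checkpoint (the `hin` default from the arguments class below) and the sample sentence are illustrative, and the int16 conversion mirrors what `process()` does before chunking.

```python
# Standalone sketch of the VITS path FacebookMMSTTSHandler wraps.
import numpy as np
import torch
from transformers import VitsModel, AutoTokenizer

model = VitsModel.from_pretrained("facebook/mms-tts-hin")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hin")

inputs = tokenizer("नमस्ते दुनिया", return_tensors="pt")
with torch.no_grad():
    waveform = model(**inputs).waveform  # (batch, samples), float32 roughly in [-1, 1]

# Same float -> int16 conversion process() applies before chunking.
audio_int16 = (waveform.squeeze().cpu().numpy() * 32767).astype(np.int16)
print(model.config.sampling_rate, audio_int16.shape)
```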
diff --git a/arguments_classes/facebookmms_tts_arguments.py b/arguments_classes/facebookmms_tts_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cc0196cf02f93e24ad95c3ff7373dbeb0c57a37
--- /dev/null
+++ b/arguments_classes/facebookmms_tts_arguments.py
@@ -0,0 +1,29 @@
+from dataclasses import field, dataclass
+
+@dataclass
+class FacebookMMSTTSHandlerArguments:
+    model_name: str = field(
+        default="facebook/mms-tts-hin",
+        metadata={
+            "help": "The model name to use. Default is 'facebook/mms-tts-hin'."
+        },
+    )
+    tts_language: str = field(  # Renamed to avoid conflict
+        default="en",
+        metadata={
+            "help": "The language code for the TTS model. Default is 'en' for English."
+        },
+    )
+    facebook_mms_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device to use for the TTS model. Default is 'cuda'."
+        },
+    )
+    facebook_mms_torch_dtype: str = field(
+        default="float32",
+        metadata={
+            "help": "The torch data type to use for the TTS model. Default is 'float32'."
+        },
+    )
+
diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py
index bdaa646b4578439e47f7e36104d168910b2fe602..efdec973f6f92edea0fcb3b6ee59afde3ad1a244 100644
--- a/arguments_classes/module_arguments.py
+++ b/arguments_classes/module_arguments.py
@@ -35,7 +35,7 @@ class ModuleArguments:
     tts: Optional[str] = field(
         default="parler",
         metadata={
-            "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
+            "help": "The TTS to use. Either 'parler', 'melo', 'chatTTS', or 'facebookmms'. Default is 'parler'"
         },
     )
     log_level: str = field(
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
index 5dc700bf24e2320d0065ab6db40c0adbcf4782b5..719cf4dcad469ebd25167549c4528ebc6c40d5aa 100644
--- a/arguments_classes/whisper_stt_arguments.py
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -57,7 +57,7 @@ class WhisperSTTHandlerArguments:
         metadata={
             "help": """The language for the conversation.
             Choose between 'en' (english), 'fr' (french), 'es' (spanish),
-            'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'.
+            'zh' (chinese), 'ko' (korean), 'ja' (japanese), 'hi' (hindi), or 'None'.
             If using 'auto', the language is automatically detected and can change during the conversation.
             Default is 'en'."""
         },
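For reference, the new dataclass is consumed through transformers' HfArgumentParser, which `parse_arguments()` composes with the other handler dataclasses (see the s2s_pipeline.py hunks below). A minimal sketch of how its fields surface as CLI flags, parsing the dataclass in isolation rather than as part of the full tuple:

```python
# Sketch: parsing FacebookMMSTTSHandlerArguments on its own; in the pipeline it
# is one entry in the HfArgumentParser tuple built by parse_arguments().
from transformers import HfArgumentParser
from arguments_classes.facebookmms_tts_arguments import FacebookMMSTTSHandlerArguments

parser = HfArgumentParser((FacebookMMSTTSHandlerArguments,))
(facebook_mms_tts_kwargs,) = parser.parse_args_into_dataclasses(
    ["--tts_language", "hi", "--facebook_mms_device", "cuda"]
)
print(facebook_mms_tts_kwargs.model_name)  # -> facebook/mms-tts-hin (the default)
```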
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 1da202e200c825ad8473f1e115d41cd5f8f686ff..19c7f70d5f86bbd3cbbbb3f98c8f1e9034836ca8 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -22,6 +22,7 @@ from arguments_classes.vad_arguments import VADHandlerArguments
 from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
 from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
 from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments
+from arguments_classes.facebookmms_tts_arguments import FacebookMMSTTSHandlerArguments
 import torch
 import nltk
 from rich.console import Console
@@ -82,6 +83,7 @@ def parse_arguments():
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
             ChatTTSHandlerArguments,
+            FacebookMMSTTSHandlerArguments,
         )
     )
 
@@ -167,6 +169,7 @@ def prepare_all_args(
     parler_tts_handler_kwargs,
     melo_tts_handler_kwargs,
     chat_tts_handler_kwargs,
+    facebook_mms_tts_handler_kwargs,
 ):
     prepare_module_args(
         module_kwargs,
@@ -178,9 +181,9 @@
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
+        facebook_mms_tts_handler_kwargs,
     )
 
-
     rename_args(whisper_stt_handler_kwargs, "stt")
     rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
     rename_args(language_model_handler_kwargs, "lm")
@@ -189,6 +192,7 @@
     rename_args(parler_tts_handler_kwargs, "tts")
     rename_args(melo_tts_handler_kwargs, "melo")
     rename_args(chat_tts_handler_kwargs, "chat_tts")
+    rename_args(facebook_mms_tts_handler_kwargs, "facebook_mms")
 
 
 def initialize_queues_and_events():
@@ -216,6 +220,7 @@ def build_pipeline(
     parler_tts_handler_kwargs,
     melo_tts_handler_kwargs,
     chat_tts_handler_kwargs,
+    facebook_mms_tts_handler_kwargs,
     queues_and_events,
 ):
     stop_event = queues_and_events["stop_event"]
@@ -264,7 +269,7 @@
     stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
     lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
 
-    tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
+    tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs)
 
     return ThreadManager([*comms_handlers, vad, stt, lm, tts])
 
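Pulled out of `build_pipeline()`, the wiring for the new branch amounts to the sketch below (hypothetical standalone setup: the queues and events normally come from `initialize_queues_and_events()`, and the literal `setup_kwargs` dict stands in for `vars(facebook_mms_tts_handler_kwargs)`). The constructor call mirrors the facebookmms branch added in the next hunk.

```python
# Standalone wiring mirroring the facebookmms branch of get_tts_handler().
from queue import Queue
from threading import Event

from TTS.facebookmms_handler import FacebookMMSTTSHandler

stop_event, should_listen = Event(), Event()
lm_response_queue, send_audio_chunks_queue = Queue(), Queue()

tts = FacebookMMSTTSHandler(
    stop_event,
    queue_in=lm_response_queue,
    queue_out=send_audio_chunks_queue,
    setup_args=(should_listen,),
    setup_kwargs={"language": "hi", "facebook_mms_device": "cuda"},
)
```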
@@ -337,7 +342,7 @@
         raise ValueError("The LLM should be either transformers or mlx-lm")
 
 
-def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
+def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs):
     if module_kwargs.tts == "parler":
         from TTS.parler_handler import ParlerTTSHandler
         return ParlerTTSHandler(
@@ -375,6 +380,15 @@ def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chu
             setup_args=(should_listen,),
             setup_kwargs=vars(chat_tts_handler_kwargs),
         )
+    elif module_kwargs.tts == "facebookmms":
+        from TTS.facebookmms_handler import FacebookMMSTTSHandler
+        return FacebookMMSTTSHandler(
+            stop_event,
+            queue_in=lm_response_queue,
+            queue_out=send_audio_chunks_queue,
+            setup_args=(should_listen,),
+            setup_kwargs=vars(facebook_mms_tts_handler_kwargs),
+        )
     else:
         raise ValueError("The TTS should be either parler, melo or chatTTS")
 
@@ -393,6 +407,7 @@ def main():
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
+        facebook_mms_tts_handler_kwargs,
     ) = parse_arguments()
 
     setup_logger(module_kwargs.log_level)
@@ -407,6 +422,7 @@
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
+        facebook_mms_tts_handler_kwargs,
     )
 
     queues_and_events = initialize_queues_and_events()
@@ -424,6 +440,7 @@
         parler_tts_handler_kwargs,
         melo_tts_handler_kwargs,
         chat_tts_handler_kwargs,
+        facebook_mms_tts_handler_kwargs,
         queues_and_events,
     )
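With these hooks in place, a plausible invocation of the new path (assuming the repository's usual entry point and leaving the STT and LLM stages at their defaults) is `python s2s_pipeline.py --tts facebookmms --tts_language hi --facebook_mms_device cuda`.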