diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index a55800c2d75d5cd9ad20dbb853ea556f3b97fbf2..38dc5910770a15393d93e7ca03a52a83c21d74f3 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -8,7 +8,6 @@ import torch from baseHandler import BaseHandler from rich.console import Console import logging -from shared_variables import current_language logger = logging.getLogger(__name__) console = Console() @@ -109,10 +108,8 @@ class WhisperSTTHandler(BaseHandler): )[0] language_id = self.processor.tokenizer.decode(pred_ids[0, 1]) - print("WHISPER curr lang", language_id) - logger.debug("finished whisper inference") console.print(f"[yellow]USER: {pred_text}") - console.print(f"[red]Language ID Whisper: {language_id}") + logger.debug(f"Language ID Whisper: {language_id}") yield (pred_text, language_id) diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py index f05545a31d246e7ac0af813115c00d1316ccb621..64f371d363f4943a8a1a304422cd11249b11ef52 100644 --- a/TTS/melo_handler.py +++ b/TTS/melo_handler.py @@ -5,7 +5,6 @@ import librosa import numpy as np from rich.console import Console import torch -from shared_variables import current_language logger = logging.getLogger(__name__) @@ -42,9 +41,15 @@ class MeloTTSHandler(BaseHandler): ): self.should_listen = should_listen self.device = device - self.language = "<|" + language + "|>" # 'Tokenize' the language code to do less operations - self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device) - self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]] + self.language = ( + "<|" + language + "|>" + ) # 'Tokenize' the language code to do less operations + self.model = TTS( + language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device + ) + self.speaker_id = self.model.hps.data.spk2id[ + WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"] + ] self.blocksize = blocksize self.warmup() @@ -56,18 +61,24 @@ class MeloTTSHandler(BaseHandler): language_id = None if isinstance(llm_sentence, tuple): - print("llm sentence is tuple!") llm_sentence, language_id = llm_sentence console.print(f"[green]ASSISTANT: {llm_sentence}") if language_id is not None and self.language != language_id: try: - self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id], device=self.device) - self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]] + self.model = TTS( + language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id], + device=self.device, + ) + self.speaker_id = self.model.hps.data.spk2id[ + WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id] + ] self.language = language_id except KeyError: - console.print(f"[red]Language {language_id} not supported by Melo. Using {self.language} instead.") + console.print( + f"[red]Language {language_id} not supported by Melo. Using {self.language} instead." + ) if self.device == "mps": import time @@ -79,7 +90,13 @@ class MeloTTSHandler(BaseHandler): time.time() - start ) # Removing this line makes it fail more often. I'm looking into it. - audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True) + try: + audio_chunk = self.model.tts_to_file( + llm_sentence, self.speaker_id, quiet=True + ) + except (AssertionError, RuntimeError) as e: + logger.error(f"Error in MeloTTSHandler: {e}") + audio_chunk = np.array([]) if len(audio_chunk) == 0: self.should_listen.set() return diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 2438060982149f1ccf24073ddbaa72322ef950b1..8da829834e85c856458a571bf3c7242500d8ae6b 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -49,7 +49,6 @@ console = Console() logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs - def prepare_args(args, prefix): """ Rename arguments by removing the prefix and prepares the gen_kwargs.