diff --git a/LLM/language_model.py b/LLM/language_model.py
index bc39b23be98f22571eeb23e502c6026f83a0e174..1a957625b43193f936898ea68e4b48d810e86101 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
+
 
 class LanguageModelHandler(BaseHandler):
     """
     Handles the language model part.
@@ -69,7 +78,7 @@ class LanguageModelHandler(BaseHandler):
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
 
-        dummy_input_text = "Write me a poem about Machine Learning."
+        dummy_input_text = "Repeat the word 'home'."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
         warmup_gen_kwargs = {
             "min_new_tokens": self.gen_kwargs["min_new_tokens"],
@@ -103,6 +112,10 @@ class LanguageModelHandler(BaseHandler):
 
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
         thread = Thread(
@@ -122,10 +135,10 @@ def process(self, prompt):
             printable_text += new_text
             sentences = sent_tokenize(printable_text)
             if len(sentences) > 1:
-                yield (sentences[0])
+                yield (sentences[0], language_code)
                 printable_text = new_text
 
         self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
-        yield printable_text
+        yield (printable_text, language_code)
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 6aa165c7746ecc7a239125c4af1fe69c70ec8d63..06cf613b01156aadcf0ccee9533c0d7aff039930 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -1,10 +1,10 @@
 from time import perf_counter
 from transformers import (
-    AutoModelForSpeechSeq2Seq,
     AutoProcessor,
+    AutoModelForSpeechSeq2Seq,
 )
 import torch
-
+from copy import copy
 from baseHandler import BaseHandler
 from rich.console import Console
 import logging
@@ -12,6 +12,15 @@ import logging
 logger = logging.getLogger(__name__)
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
+
 
 class WhisperSTTHandler(BaseHandler):
     """
@@ -24,12 +33,18 @@ class WhisperSTTHandler(BaseHandler):
         device="cuda",
         torch_dtype="float16",
         compile_mode=None,
+        language=None,
         gen_kwargs={},
     ):
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
         self.compile_mode = compile_mode
         self.gen_kwargs = gen_kwargs
+        if language == "auto":
+            language = None
+        self.last_language = language
+        if self.last_language is not None:
+            self.gen_kwargs["language"] = self.last_language
 
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -102,11 +117,23 @@ class WhisperSTTHandler(BaseHandler):
         input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
 
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"
+
+        if language_code not in SUPPORTED_LANGUAGES:  # reprocess with the last language
+            logger.warning(f"Whisper detected unsupported language: {language_code}")
+            gen_kwargs = copy(self.gen_kwargs)
+            gen_kwargs["language"] = self.last_language
+            language_code = self.last_language
+            pred_ids = self.model.generate(input_features, **gen_kwargs)
+        else:
+            self.last_language = language_code
+
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)
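Note on the detection step above: Whisper's generated sequence starts with `<|startoftranscript|><|xx|><|transcribe|>`, so `pred_ids[0, 1]` is the language token and stripping `<|`/`|>` leaves the two-letter code. A minimal sketch of that detect-then-fall-back logic (illustrative only; the `detect_language` helper is not part of this diff):

```python
# Illustrative sketch of the detection + fallback added in WhisperSTTHandler.process().
SUPPORTED_LANGUAGES = ["en", "fr", "es", "zh", "ja", "ko"]

def detect_language(pred_ids, tokenizer, last_language="en"):
    token = tokenizer.decode(pred_ids[0, 1])  # e.g. "<|fr|>"
    code = token[2:-2]                        # strip "<|" and "|>"
    # Unsupported detections fall back to the last language that worked,
    # mirroring the reprocessing branch in the handler above.
    return code if code in SUPPORTED_LANGUAGES else last_language
```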
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 376747bc8982332ac66f8c6d41814f8bb2ef69e8..b1b222614da719cc1cc50a1826892d661f859105 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -10,21 +10,44 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
+    "en": "EN_NEWEST",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
+WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
+    "en": "EN-Newest",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
 
 class MeloTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
         device="mps",
-        language="EN_NEWEST",
-        speaker_to_id="EN-Newest",
+        language="en",
+        speaker_to_id="en",
         gen_kwargs={},  # Unused
         blocksize=512,
     ):
         self.should_listen = should_listen
         self.device = device
-        self.model = TTS(language=language, device=device)
-        self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
+        self.language = language
+        self.model = TTS(
+            language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
+        )
+        self.speaker_id = self.model.hps.data.spk2id[
+            WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
+        ]
         self.blocksize = blocksize
         self.warmup()
@@ -33,7 +56,28 @@ class MeloTTSHandler(BaseHandler):
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
     def process(self, llm_sentence):
+        language_code = None
+
+        if isinstance(llm_sentence, tuple):
+            llm_sentence, language_code = llm_sentence
+
         console.print(f"[green]ASSISTANT: {llm_sentence}")
+
+        if language_code is not None and self.language != language_code:
+            try:
+                self.model = TTS(
+                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
+                    device=self.device,
+                )
+                self.speaker_id = self.model.hps.data.spk2id[
+                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
+                ]
+                self.language = language_code
+            except KeyError:
+                console.print(
+                    f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
+                )
+
         if self.device == "mps":
             import time
@@ -44,7 +88,13 @@ class MeloTTSHandler(BaseHandler):
                 time.time() - start
             )  # Removing this line makes it fail more often. I'm looking into it.
 
-        audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
+        try:
+            audio_chunk = self.model.tts_to_file(
+                llm_sentence, self.speaker_id, quiet=True
+            )
+        except (AssertionError, RuntimeError) as e:
+            logger.error(f"Error in MeloTTSHandler: {e}")
+            audio_chunk = np.array([])
         if len(audio_chunk) == 0:
             self.should_listen.set()
             return
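MeloTTS identifies languages and speakers by its own keys, which is why the handler now carries two lookup tables keyed by Whisper-style codes. A rough usage sketch of those tables (assumes MeloTTS is installed and exposes `melo.api.TTS` as used by the handler; the French sentence and the `"cpu"` device are placeholder values):

```python
# Sketch: build a Melo voice for a Whisper-style language code.
from melo.api import TTS

WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {"en": "EN_NEWEST", "fr": "FR"}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {"en": "EN-Newest", "fr": "FR"}

language_code = "fr"  # as yielded by the STT stage
model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code], device="cpu")
speaker_id = model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]]
# With no output path, tts_to_file returns the raw audio; the handler above
# relies on the same behaviour when it checks len(audio_chunk).
audio = model.tts_to_file("Bonjour, comment allez-vous ?", speaker_id, quiet=True)
```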
diff --git a/arguments_classes/melo_tts_arguments.py b/arguments_classes/melo_tts_arguments.py
index 49fd3578ac0e376ab264baa7900e5fc0cc879c76..7223489318f843c2919b7d3929580040c88307aa 100644
--- a/arguments_classes/melo_tts_arguments.py
+++ b/arguments_classes/melo_tts_arguments.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
 @dataclass
 class MeloTTSHandlerArguments:
     melo_language: str = field(
-        default="EN_NEWEST",
+        default="en",
         metadata={
-            "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'."
+            "help": "The language of the text to be synthesized. Default is 'en'."
         },
@@ -16,7 +16,7 @@ class MeloTTSHandlerArguments:
         },
     )
     melo_speaker_to_id: str = field(
-        default="EN-Newest",
+        default="en",
         metadata={
-            "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
+            "help": "Mapping of speaker names to speaker IDs. Default is 'en'."
         },
     )
diff --git a/arguments_classes/whisper_stt_arguments.py b/arguments_classes/whisper_stt_arguments.py
index bed382dda754da36965b4d86e68a7f8b4d9c322c..5dc700bf24e2320d0065ab6db40c0adbcf4782b5 100644
--- a/arguments_classes/whisper_stt_arguments.py
+++ b/arguments_classes/whisper_stt_arguments.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import Optional
 
 
 @dataclass
@@ -51,9 +52,13 @@ class WhisperSTTHandlerArguments:
             "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
         },
     )
-    stt_gen_language: str = field(
-        default="en",
+    language: Optional[str] = field(
+        default="en",
         metadata={
-            "help": "The language of the speech to transcribe. Default is 'en' for English."
+            "help": """The language for the conversation.
+            Choose between 'en' (english), 'fr' (french), 'es' (spanish),
+            'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'auto'.
+            If 'auto' is used, the language is detected automatically and can
+            change during the conversation. Default is 'en'."""
         },
-    )
+    )
\ No newline at end of file
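Taken together, the diff changes the inter-handler contract: each stage now passes a `(payload, language_code)` tuple downstream instead of a bare string, so the detected language survives from transcription through generation to synthesis. A toy illustration of that flow (hypothetical values, not the real pipeline wiring):

```python
# Toy data-flow sketch of the (payload, language_code) protocol.
def stt_stage(audio_chunk):
    # WhisperSTTHandler.process yields (transcript, detected_language)
    return ("Bonjour, comment ça va ?", "fr")

def llm_stage(item):
    text, language_code = item      # LanguageModelHandler unpacks the tuple ...
    reply = "Très bien, merci !"
    return (reply, language_code)   # ... and forwards the code with each sentence

def tts_stage(item):
    sentence, language_code = item  # MeloTTSHandler switches voice when the code changes
    print(f"[{language_code}] {sentence}")

tts_stage(llm_stage(stt_stage(b"\x00\x01")))  # -> [fr] Très bien, merci !
```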