diff --git a/LLM/language_model.py b/LLM/language_model.py
index bc39b23be98f22571eeb23e502c6026f83a0e174..e0d5829a138d87135690b8aca0d5a5db4093e7c3 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -101,7 +101,7 @@ class LanguageModelHandler(BaseHandler):
             f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
         )
 
-    def process(self, prompt):
+    def process(self, prompt, language_id=None):
         logger.debug("infering language model...")
 
         self.chat.append({"role": self.user_role, "content": prompt})
@@ -128,4 +128,4 @@ class LanguageModelHandler(BaseHandler):
         self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
-        yield printable_text
+        yield (printable_text, language_id)
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index e964a7c9f66fc33e69374b2246114a7642e62d73..a55800c2d75d5cd9ad20dbb853ea556f3b97fbf2 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -108,13 +108,11 @@ class WhisperSTTHandler(BaseHandler):
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
         language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
-        for char in "<>|":
-            language_id = language_id.replace(char, "")  # remove special tokens
 
-        global current_language
-        current_language = language_id
+        logger.debug("whisper detected language token: %s", language_id)
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        console.print(f"[red]Language ID Whisper: {language_id}")
 
-        yield pred_text
+        yield (pred_text, language_id)
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 06afcc635af4a209400cb039378ac401625db8de..d5cd53ce18c25906d2ba2d264f76633803cb7c83 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -12,21 +12,21 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "en": "EN_NEWEST",
-    "fr": "FR",
-    "es": "ES",
-    "zh": "ZH",
-    "ja": "JP",
-    "ko": "KR",
+    "<|en|>": "EN_NEWEST",
+    "<|fr|>": "FR",
+    "<|es|>": "ES",
+    "<|zh|>": "ZH",
+    "<|ja|>": "JP",
+    "<|ko|>": "KR",
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "en": "EN-Newest",
-    "fr": "FR",
-    "es": "ES",
-    "zh": "ZH",
-    "ja": "JP",
-    "ko": "KR",
+    "<|en|>": "EN-Newest",
+    "<|fr|>": "FR",
+    "<|es|>": "ES",
+    "<|zh|>": "ZH",
+    "<|ja|>": "JP",
+    "<|ko|>": "KR",
 }
 
@@ -42,7 +42,7 @@ class MeloTTSHandler(BaseHandler):
     ):
         self.should_listen = should_listen
         self.device = device
-        self.language = language
-        self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device)
-        self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]]
+        self.language = "<|" + language + "|>"  # 'Tokenize' the language code to do less operations
+        self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device)
+        self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]]
         self.blocksize = blocksize
@@ -52,9 +52,12 @@ class MeloTTSHandler(BaseHandler):
         logger.info(f"Warming up {self.__class__.__name__}")
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
-    def process(self, llm_sentence):
+    def process(self, llm_sentence, language_id=None):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
-        if self.language != current_language:
-            self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
-            self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
+
+        # Reload only for a detected language Melo has a mapping for; None/unmapped tokens keep the current voice.
+        if language_id in WHISPER_LANGUAGE_TO_MELO_LANGUAGE and self.language != language_id:
+            self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id], device=self.device)
+            self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]]
+            self.language = language_id
 