From 2cb9464b8fd28799c7b622f44ef79f2485475add Mon Sep 17 00:00:00 2001 From: Andres Marafioti <andimarafioti@gmail.com> Date: Thu, 29 Aug 2024 10:30:09 +0200 Subject: [PATCH] try to pass the language_id in the queue --- LLM/language_model.py | 4 ++-- STT/whisper_stt_handler.py | 8 +++----- TTS/melo_handler.py | 31 ++++++++++++++++--------------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/LLM/language_model.py b/LLM/language_model.py index bc39b23..e0d5829 100644 --- a/LLM/language_model.py +++ b/LLM/language_model.py @@ -101,7 +101,7 @@ class LanguageModelHandler(BaseHandler): f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" ) - def process(self, prompt): + def process(self, prompt, language_id=None): logger.debug("infering language model...") self.chat.append({"role": self.user_role, "content": prompt}) @@ -128,4 +128,4 @@ class LanguageModelHandler(BaseHandler): self.chat.append({"role": "assistant", "content": generated_text}) # don't forget last sentence - yield printable_text + yield (printable_text, language_id) diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index e964a7c..a55800c 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -108,13 +108,11 @@ class WhisperSTTHandler(BaseHandler): pred_ids, skip_special_tokens=True, decode_with_timestamps=False )[0] language_id = self.processor.tokenizer.decode(pred_ids[0, 1]) - for char in "<>|": - language_id = language_id.replace(char, "") # remove special tokens - global current_language - current_language = language_id + print("WHISPER curr lang", language_id) logger.debug("finished whisper inference") console.print(f"[yellow]USER: {pred_text}") + console.print(f"[red]Language ID Whisper: {language_id}") - yield pred_text + yield (pred_text, language_id) diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py index 06afcc6..d5cd53c 100644 --- a/TTS/melo_handler.py +++ b/TTS/melo_handler.py @@ -12,21 +12,21 @@ logger = logging.getLogger(__name__) console = Console() WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { - "en": "EN_NEWEST", - "fr": "FR", - "es": "ES", - "zh": "ZH", - "ja": "JP", - "ko": "KR", + "<|en|>": "EN_NEWEST", + "<|fr|>": "FR", + "<|es|>": "ES", + "<|zh|>": "ZH", + "<|ja|>": "JP", + "<|ko|>": "KR", } WHISPER_LANGUAGE_TO_MELO_SPEAKER = { - "en": "EN-Newest", - "fr": "FR", - "es": "ES", - "zh": "ZH", - "ja": "JP", - "ko": "KR", + "<|en|>": "EN-Newest", + "<|fr|>": "FR", + "<|es|>": "ES", + "<|zh|>": "ZH", + "<|ja|>": "JP", + "<|ko|>": "KR", } @@ -42,7 +42,7 @@ class MeloTTSHandler(BaseHandler): ): self.should_listen = should_listen self.device = device - self.language = language + self.language = "<|" + language + "|>" # 'Tokenize' the language code to do less operations self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device) self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]] self.blocksize = blocksize @@ -52,9 +52,10 @@ class MeloTTSHandler(BaseHandler): logger.info(f"Warming up {self.__class__.__name__}") _ = self.model.tts_to_file("text", self.speaker_id, quiet=True) - def process(self, llm_sentence): + def process(self, llm_sentence, language_id=None): console.print(f"[green]ASSISTANT: {llm_sentence}") - if self.language != current_language: + + if language_id is not None and self.language != language_id: self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device) self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]] -- GitLab