From 2cb9464b8fd28799c7b622f44ef79f2485475add Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Thu, 29 Aug 2024 10:30:09 +0200
Subject: [PATCH] try to pass the language_id in the queue

---
 LLM/language_model.py      |  4 ++--
 STT/whisper_stt_handler.py |  8 +++-----
 TTS/melo_handler.py        | 31 ++++++++++++++++---------------
 3 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/LLM/language_model.py b/LLM/language_model.py
index bc39b23..e0d5829 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -101,7 +101,7 @@ class LanguageModelHandler(BaseHandler):
                 f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
             )
 
-    def process(self, prompt):
+    def process(self, prompt, language_id=None):
         logger.debug("infering language model...")
 
         self.chat.append({"role": self.user_role, "content": prompt})
@@ -128,4 +128,4 @@ class LanguageModelHandler(BaseHandler):
         self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
-        yield printable_text
+        yield (printable_text, language_id)
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index e964a7c..a55800c 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -108,13 +108,11 @@ class WhisperSTTHandler(BaseHandler):
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
         language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
-        for char in "<>|":
-            language_id = language_id.replace(char, "") # remove special tokens
 
-        global current_language
-        current_language = language_id
+        print("WHISPER curr lang", language_id)
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        console.print(f"[red]Language ID Whisper: {language_id}")
 
-        yield pred_text
+        yield (pred_text, language_id)
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 06afcc6..d5cd53c 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -12,21 +12,21 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "en": "EN_NEWEST",
-    "fr": "FR",
-    "es": "ES",
-    "zh": "ZH",
-    "ja": "JP",
-    "ko": "KR",
+    "<|en|>": "EN_NEWEST",
+    "<|fr|>": "FR",
+    "<|es|>": "ES",
+    "<|zh|>": "ZH",
+    "<|ja|>": "JP",
+    "<|ko|>": "KR",
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "en": "EN-Newest",
-    "fr": "FR",
-    "es": "ES",
-    "zh": "ZH",
-    "ja": "JP",
-    "ko": "KR",
+    "<|en|>": "EN-Newest",
+    "<|fr|>": "FR",
+    "<|es|>": "ES",
+    "<|zh|>": "ZH",
+    "<|ja|>": "JP",
+    "<|ko|>": "KR",
 }
 
 
@@ -42,7 +42,7 @@ class MeloTTSHandler(BaseHandler):
     ):
         self.should_listen = should_listen
         self.device = device
-        self.language = language
+        self.language = "<|" + language + "|>"  # 'Tokenize' the language code to do less operations
         self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device)
         self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]]
         self.blocksize = blocksize
@@ -52,9 +52,10 @@ class MeloTTSHandler(BaseHandler):
         logger.info(f"Warming up {self.__class__.__name__}")
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
-    def process(self, llm_sentence):
+    def process(self, llm_sentence, language_id=None):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
-        if self.language != current_language:
+
+        if language_id is not None and self.language != language_id:
             self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
             self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
 
-- 
GitLab