From 65f779de83a374830ee37d9750a5b3ddecfdca95 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Wed, 4 Sep 2024 13:39:58 +0200
Subject: [PATCH] Use bare language codes instead of Whisper "<|xx|>" tokens (review from eustache)

---
 LLM/language_model.py      | 22 +++++++++----------
 STT/whisper_stt_handler.py | 28 ++++++++++++------------
 TTS/melo_handler.py        | 44 ++++++++++++++++++--------------------
 3 files changed, 46 insertions(+), 48 deletions(-)

diff --git a/LLM/language_model.py b/LLM/language_model.py
index 6b48017..1a95762 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -19,12 +19,12 @@ console = Console()
 
 
 WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
-    "<|en|>": "english",
-    "<|fr|>": "french",
-    "<|es|>": "spanish",
-    "<|zh|>": "chinese",
-    "<|ja|>": "japanese",
-    "<|ko|>": "korean",
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
 }
 
 class LanguageModelHandler(BaseHandler):
@@ -112,10 +112,10 @@ class LanguageModelHandler(BaseHandler):
 
     def process(self, prompt):
         logger.debug("infering language model...")
-        language_id = None
+        language_code = None
         if isinstance(prompt, tuple):
-            prompt, language_id = prompt
-            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
         thread = Thread(
@@ -135,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
                 printable_text += new_text
                 sentences = sent_tokenize(printable_text)
                 if len(sentences) > 1:
-                    yield (sentences[0], language_id)
+                    yield (sentences[0], language_code)
                     printable_text = new_text
 
         self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
-        yield (printable_text, language_id)
+        yield (printable_text, language_code)
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index d669e34..30d9307 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -13,12 +13,12 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 SUPPORTED_LANGUAGES = [
-    "<|en|>",
-    "<|fr|>",
-    "<|es|>",
-    "<|zh|>",
-    "<|ja|>",
-    "<|ko|>",
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
 ]
 
 
@@ -117,24 +117,24 @@ class WhisperSTTHandler(BaseHandler):
 
         input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
-        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"
 
-        if language_id not in SUPPORTED_LANGUAGES:  # reprocess with the last language
-            logger.warning("Whisper detected unsupported language:", language_id)
+        if language_code not in SUPPORTED_LANGUAGES:  # reprocess with the last language
+            logger.warning("Whisper detected unsupported language: %s", language_code)
             gen_kwargs = copy(self.gen_kwargs)
             gen_kwargs['language'] = self.last_language
-            language_id = self.last_language
+            language_code = self.last_language
             pred_ids = self.model.generate(input_features, **gen_kwargs)
         else:
-            self.last_language = language_id
+            self.last_language = language_code
         
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
-        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
-        logger.debug(f"Language ID Whisper: {language_id}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield (pred_text, language_id)
+        yield (pred_text, language_code)
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 64f371d..b1b2226 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -11,21 +11,21 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "<|en|>": "EN_NEWEST",
-    "<|fr|>": "FR",
-    "<|es|>": "ES",
-    "<|zh|>": "ZH",
-    "<|ja|>": "JP",
-    "<|ko|>": "KR",
+    "en": "EN_NEWEST",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "<|en|>": "EN-Newest",
-    "<|fr|>": "FR",
-    "<|es|>": "ES",
-    "<|zh|>": "ZH",
-    "<|ja|>": "JP",
-    "<|ko|>": "KR",
+    "en": "EN-Newest",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
 }
 
 
@@ -41,14 +41,12 @@ class MeloTTSHandler(BaseHandler):
     ):
         self.should_listen = should_listen
         self.device = device
-        self.language = (
-            "<|" + language + "|>"
-        )  # 'Tokenize' the language code to do less operations
+        self.language = language
         self.model = TTS(
             language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
         )
         self.speaker_id = self.model.hps.data.spk2id[
-            WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]
+            WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
         ]
         self.blocksize = blocksize
         self.warmup()
@@ -58,26 +56,26 @@ class MeloTTSHandler(BaseHandler):
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
     def process(self, llm_sentence):
-        language_id = None
+        language_code = None
 
         if isinstance(llm_sentence, tuple):
-            llm_sentence, language_id = llm_sentence
+            llm_sentence, language_code = llm_sentence
 
         console.print(f"[green]ASSISTANT: {llm_sentence}")
 
-        if language_id is not None and self.language != language_id:
+        if language_code is not None and self.language != language_code:
             try:
                 self.model = TTS(
-                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],
+                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
                     device=self.device,
                 )
                 self.speaker_id = self.model.hps.data.spk2id[
-                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]
+                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
                 ]
-                self.language = language_id
+                self.language = language_code
             except KeyError:
                 console.print(
-                    f"[red]Language {language_id} not supported by Melo. Using {self.language} instead."
+                    f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
                 )
 
         if self.device == "mps":
-- 
GitLab