From 29fedd0720cf3fe60b03f635471b3b5e15c79b01 Mon Sep 17 00:00:00 2001
From: Elli0t <ybmelliot@gmail.com>
Date: Sun, 8 Sep 2024 02:51:23 +0800
Subject: [PATCH] Update: Added multi-language support for macOS

---
 LLM/mlx_language_model.py            | 17 +++++++++++++
 STT/lightning_whisper_mlx_handler.py | 37 ++++++++++++++++++++++++++--
 TTS/melo_handler.py                  |  4 +--
 3 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py
index 7e041de..191b1d6 100644
--- a/LLM/mlx_language_model.py
+++ b/LLM/mlx_language_model.py
@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
 
 class MLXLanguageModelHandler(BaseHandler):
     """
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
 
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
 
@@ -89,6 +102,10 @@ class MLXLanguageModelHandler(BaseHandler):
                 yield curr_output.replace("<|end|>", "")
                 curr_output = ""
         generated_text = output.replace("<|end|>", "")
+        printable_text = generated_text
         torch.mps.empty_cache()
 
         self.chat.append({"role": "assistant", "content": generated_text})
+
+        # don't forget last sentence
+        yield (printable_text, language_code)
\ No newline at end of file
diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
index 4785b73..9d48740 100644
--- a/STT/lightning_whisper_mlx_handler.py
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
 from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
+from copy import copy
 import torch
 
 logger = logging.getLogger(__name__)
 
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
+
 
 class LightningWhisperSTTHandler(BaseHandler):
     """
@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-large-v3",
-        device="cuda",
+        device="mps",
         torch_dtype="float16",
         compile_mode=None,
         language=None,
@@ -29,6 +39,12 @@ class LightningWhisperSTTHandler(BaseHandler):
             model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+        if language == 'auto':
+            language = None
+        self.last_language = language
+        if self.last_language is not None:
+            self.gen_kwargs["language"] = self.last_language
+
         self.warmup()
 
     def warmup(self):
@@ -47,10 +63,27 @@ class LightningWhisperSTTHandler(BaseHandler):
         global pipeline_start
         pipeline_start = perf_counter()
 
+        # language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"
+
+        language_code = self.model.transcribe(spoken_prompt)["language"]
+
+        if language_code not in SUPPORTED_LANGUAGES:  # reprocess with the last language
+            logger.warning("Whisper detected unsupported language: %s", language_code)
+            gen_kwargs = copy(self.gen_kwargs)
+            gen_kwargs['language'] = self.last_language
+            language_code = self.last_language
+            # pred_ids = self.model.generate(input_features, **gen_kwargs)
+        else:
+            self.last_language = language_code
+
         pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
         torch.mps.empty_cache()
 
+        # language_code was already determined above; re-detecting it here would
+        # transcribe the audio a third time and override the unsupported-language fallback.
+
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index b1b2226..6dd50f1 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "en": "EN_NEWEST",
+    "en": "EN",
     "fr": "FR",
     "es": "ES",
     "zh": "ZH",
@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "en": "EN-Newest",
+    "en": "EN-BR",
     "fr": "FR",
     "es": "ES",
     "zh": "ZH",
-- 
GitLab