From 29fedd0720cf3fe60b03f635471b3b5e15c79b01 Mon Sep 17 00:00:00 2001
From: Elli0t <ybmelliot@gmail.com>
Date: Sun, 8 Sep 2024 02:51:23 +0800
Subject: [PATCH] Update: Added multi-language support for macOS

---
 LLM/mlx_language_model.py            | 23 +++++++++++++++++
 STT/lightning_whisper_mlx_handler.py | 37 +++++++++++++++++++++++++---
 TTS/melo_handler.py                  |  4 +--
 3 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py
index 7e041de..191b1d6 100644
--- a/LLM/mlx_language_model.py
+++ b/LLM/mlx_language_model.py
@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
 
 class MLXLanguageModelHandler(BaseHandler):
     """
@@ -61,6 +69,15 @@ class MLXLanguageModelHandler(BaseHandler):
 
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            # Only steer the reply language when the detected code has a known
+            # name; an unmapped or missing code must not raise KeyError here.
+            language_name = WHISPER_LANGUAGE_TO_LLM_LANGUAGE.get(language_code)
+            if language_name is not None:
+                prompt = f"Please reply to my message in {language_name}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
 
@@ -89,6 +106,12 @@ class MLXLanguageModelHandler(BaseHandler):
                 yield curr_output.replace("<|end|>", "")
                 curr_output = ""
         generated_text = output.replace("<|end|>", "")
+        printable_text = generated_text
         torch.mps.empty_cache()
 
         self.chat.append({"role": "assistant", "content": generated_text})
+
+        # don't forget last sentence
+        # NOTE(review): printable_text is the whole reply, not only the trailing
+        # fragment - confirm the sentences yielded above are not emitted twice.
+        yield (printable_text, language_code)
\ No newline at end of file
diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
index 4785b73..9d48740 100644
--- a/STT/lightning_whisper_mlx_handler.py
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
 from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
+from copy import copy
 import torch
 
 logger = logging.getLogger(__name__)
 
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
+
 
 class LightningWhisperSTTHandler(BaseHandler):
     """
@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-large-v3",
-        device="cuda",
+        device="mps",
         torch_dtype="float16",
         compile_mode=None,
         language=None,
@@ -29,6 +39,14 @@ class LightningWhisperSTTHandler(BaseHandler):
             model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+        if language == 'auto':
+            language = None
+        self.last_language = language
+        if self.last_language is not None:
+            # NOTE(review): self.gen_kwargs is not assigned anywhere in this
+            # patch - confirm it exists before this line runs.
+            self.gen_kwargs["language"] = self.last_language
+
         self.warmup()
 
     def warmup(self):
@@ -47,10 +65,23 @@ class LightningWhisperSTTHandler(BaseHandler):
         global pipeline_start
         pipeline_start = perf_counter()
 
+        # Transcribe once: the result carries both the text and the language.
+        transcription = self.model.transcribe(spoken_prompt)
+        language_code = transcription["language"]
+
+        if language_code not in SUPPORTED_LANGUAGES:
+            # Fall back to the last supported language so later stages never
+            # receive a code they have no mapping for.
+            logger.warning("Whisper detected unsupported language: %s", language_code)
+            language_code = self.last_language
+        else:
+            self.last_language = language_code
+
-        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+        pred_text = transcription["text"].strip()
         torch.mps.empty_cache()
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index b1b2226..6dd50f1 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "en": "EN_NEWEST",
+    "en": "EN",
     "fr": "FR",
     "es": "ES",
     "zh": "ZH",
@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "en": "EN-Newest",
+    "en": "EN-BR",
     "fr": "FR",
     "es": "ES",
     "zh": "ZH",
-- 
GitLab