diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py index 7e041de3e8bcb82ed9332eb778c6b03884f010b0..191b1d67371f3174519272149c3590ffc32e97dc 100644 --- a/LLM/mlx_language_model.py +++ b/LLM/mlx_language_model.py @@ -9,6 +9,14 @@ logger = logging.getLogger(__name__) console = Console() +WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { + "en": "english", + "fr": "french", + "es": "spanish", + "zh": "chinese", + "ja": "japanese", + "ko": "korean", +} class MLXLanguageModelHandler(BaseHandler): """ @@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler): def process(self, prompt): logger.debug("infering language model...") + language_code = None + + if isinstance(prompt, tuple): + prompt, language_code = prompt + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt self.chat.append({"role": self.user_role, "content": prompt}) @@ -89,6 +102,10 @@ class MLXLanguageModelHandler(BaseHandler): yield curr_output.replace("<|end|>", "") curr_output = "" generated_text = output.replace("<|end|>", "") + printable_text = generated_text torch.mps.empty_cache() self.chat.append({"role": "assistant", "content": generated_text}) + + # don't forget last sentence + yield (printable_text, language_code) \ No newline at end of file diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py index 4785b73853275e5a308ae9da48667b0b656297df..9d48740b3f3d04efbc8e991a84e00fb226c238e6 100644 --- a/STT/lightning_whisper_mlx_handler.py +++ b/STT/lightning_whisper_mlx_handler.py @@ -4,12 +4,22 @@ from baseHandler import BaseHandler from lightning_whisper_mlx import LightningWhisperMLX import numpy as np from rich.console import Console +from copy import copy import torch logger = logging.getLogger(__name__) console = Console() +SUPPORTED_LANGUAGES = [ + "en", + "fr", + "es", + "zh", + "ja", + "ko", +] + class LightningWhisperSTTHandler(BaseHandler): """ @@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler): def setup( self, model_name="distil-large-v3", - device="cuda", + device="mps", torch_dtype="float16", compile_mode=None, language=None, @@ -29,6 +39,12 @@ class LightningWhisperSTTHandler(BaseHandler): model_name = model_name.split("/")[-1] self.device = device self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None) + if language == 'auto': + language = None + self.last_language = language + if self.last_language is not None: + self.gen_kwargs["language"] = self.last_language + self.warmup() def warmup(self): @@ -47,10 +63,27 @@ class LightningWhisperSTTHandler(BaseHandler): global pipeline_start pipeline_start = perf_counter() + # language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" + + language_code = self.model.transcribe(spoken_prompt)["language"] + + if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language + logger.warning("Whisper detected unsupported language:", language_code) + gen_kwargs = copy(self.gen_kwargs) ##### + gen_kwargs['language'] = self.last_language + language_code = self.last_language + # pred_ids = self.model.generate(input_features, **gen_kwargs) + else: + self.last_language = language_code + pred_text = self.model.transcribe(spoken_prompt)["text"].strip() torch.mps.empty_cache() + # language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" + language_code = self.model.transcribe(spoken_prompt)["language"] + logger.debug("finished whisper inference") console.print(f"[yellow]USER: {pred_text}") + logger.debug(f"Language Code Whisper: {language_code}") - yield pred_text + yield (pred_text, language_code) diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py index b1b222614da719cc1cc50a1826892d661f859105..6dd50f1330d49f6edc7b493bcf292b3b158304a6 100644 --- a/TTS/melo_handler.py +++ b/TTS/melo_handler.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) console = Console() WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { - "en": "EN_NEWEST", + "en": "EN", "fr": "FR", "es": "ES", "zh": "ZH", @@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { } WHISPER_LANGUAGE_TO_MELO_SPEAKER = { - "en": "EN-Newest", + "en": "EN-BR", "fr": "FR", "es": "ES", "zh": "ZH",