diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py index 191b1d67371f3174519272149c3590ffc32e97dc..ae11b35e7e99e608cea493a5184d18734e2a54eb 100644 --- a/LLM/mlx_language_model.py +++ b/LLM/mlx_language_model.py @@ -99,13 +99,9 @@ class MLXLanguageModelHandler(BaseHandler): output += t curr_output += t if curr_output.endswith((".", "?", "!", "<|end|>")): - yield curr_output.replace("<|end|>", "") + yield (curr_output.replace("<|end|>", ""), language_code) curr_output = "" generated_text = output.replace("<|end|>", "") - printable_text = generated_text torch.mps.empty_cache() - self.chat.append({"role": "assistant", "content": generated_text}) - - # don't forget last sentence - yield (printable_text, language_code) \ No newline at end of file + self.chat.append({"role": "assistant", "content": generated_text}) \ No newline at end of file diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py index 9d48740b3f3d04efbc8e991a84e00fb226c238e6..6f9fbb217bb32b227096e4594f6e648fa879d37a 100644 --- a/STT/lightning_whisper_mlx_handler.py +++ b/STT/lightning_whisper_mlx_handler.py @@ -39,11 +39,8 @@ class LightningWhisperSTTHandler(BaseHandler): model_name = model_name.split("/")[-1] self.device = device self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None) - if language == 'auto': - language = None + self.start_language = language self.last_language = language - if self.last_language is not None: - self.gen_kwargs["language"] = self.last_language self.warmup() @@ -63,25 +60,24 @@ class LightningWhisperSTTHandler(BaseHandler): global pipeline_start pipeline_start = perf_counter() - # language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" - - language_code = self.model.transcribe(spoken_prompt)["language"] - - if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language - logger.warning("Whisper detected unsupported language:", language_code) - gen_kwargs = copy(self.gen_kwargs) ##### - gen_kwargs['language'] = self.last_language - language_code = self.last_language - # pred_ids = self.model.generate(input_features, **gen_kwargs) + if self.start_language != 'auto': + transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language) else: - self.last_language = language_code - - pred_text = self.model.transcribe(spoken_prompt)["text"].strip() + transcription_dict = self.model.transcribe(spoken_prompt) + language_code = transcription_dict["language"] + if language_code not in SUPPORTED_LANGUAGES: + logger.warning(f"Whisper detected unsupported language: {language_code}") + if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language + transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language) + else: + transcription_dict = {"text": "", "language": "en"} + else: + self.last_language = language_code + + pred_text = transcription_dict["text"].strip() + language_code = transcription_dict["language"] torch.mps.empty_cache() - # language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" - language_code = self.model.transcribe(spoken_prompt)["language"] - logger.debug("finished whisper inference") console.print(f"[yellow]USER: {pred_text}") logger.debug(f"Language Code Whisper: {language_code}") diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py index 6dd50f1330d49f6edc7b493bcf292b3b158304a6..b1b222614da719cc1cc50a1826892d661f859105 100644 --- a/TTS/melo_handler.py +++ b/TTS/melo_handler.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) console = Console() WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { - "en": "EN", + "en": "EN_NEWEST", "fr": "FR", "es": "ES", "zh": "ZH", @@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { } WHISPER_LANGUAGE_TO_MELO_SPEAKER = { - "en": "EN-BR", + "en": "EN-Newest", "fr": "FR", "es": "ES", "zh": "ZH",