From cfd3065e347bc0cc6308b2c9b30af0e9acc8e1e6 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Wed, 28 Aug 2024 17:07:32 +0200
Subject: [PATCH] language fun

---
Review notes (this section is ignored when the patch is applied):
- TTS/melo_handler.py, process(): the earlier revision of this hunk rebuilt
  the model from the stale self.language and never updated it, so the voice
  never changed and the same model was reloaded on every sentence while the
  two values differed. It now adopts current_language before reloading, and
  skips codes that have no Melo mapping instead of raising KeyError.
- TTS/melo_handler.py, setup(): language / speaker_to_id are looked up with a
  pass-through fallback, so native Melo names such as "EN_NEWEST" and
  "EN-Newest" (the previous defaults) keep working.
- NOTE(review): `global current_language` binds a per-module global. The
  assignment in STT/whisper_stt_handler.py is therefore not visible from
  TTS/melo_handler.py or s2s_pipeline.py, and the hunks shown here never
  define the name inside melo_handler.py. The detected language has to travel
  through shared state or the pipeline queues; that redesign is outside the
  scope of this patch.
- Line breaks and indentation were reconstructed from the Python structure
  and the hunk line counts; the post-image blob ids on the `index` lines were
  not regenerated.

 STT/whisper_stt_handler.py |  6 ++++++
 TTS/melo_handler.py        | 37 +++++++++++++++++++++++++++++++++----
 s2s_pipeline.py            |  2 ++
 3 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 6aa165c..f1f08c5 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -105,6 +105,12 @@ class WhisperSTTHandler(BaseHandler):
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
+        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+        for char in "<>|":
+            language_id = language_id.replace(char, "")  # remove special tokens
+
+        global current_language
+        current_language = language_id
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 376747b..24e366d 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -10,21 +10,40 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
+    "en": "EN_NEWEST",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
+WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
+    "en": "EN-Newest",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
 
 class MeloTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
         device="mps",
-        language="EN_NEWEST",
-        speaker_to_id="EN-Newest",
+        language="en",
+        speaker_to_id="en",
         gen_kwargs={},  # Unused
         blocksize=512,
     ):
         self.should_listen = should_listen
         self.device = device
-        self.model = TTS(language=language, device=device)
-        self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
+        self.language = language
+        self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE.get(language, language), device=device)
+        self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER.get(speaker_to_id, speaker_to_id)]
         self.blocksize = blocksize
         self.warmup()
 
@@ -34,6 +53,16 @@ class MeloTTSHandler(BaseHandler):
 
     def process(self, llm_sentence):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
+        global current_language
+        if (
+            self.language != current_language
+            and current_language in WHISPER_LANGUAGE_TO_MELO_LANGUAGE
+        ):
+            # Adopt the new language first so the reload below targets it.
+            self.language = current_language
+            self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
+            self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
+
         if self.device == "mps":
             import time
 
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 8da8298..d85ade8 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -48,6 +48,8 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
 console = Console()
 logging.getLogger("numba").setLevel(logging.WARNING)  # quiet down numba logs
 
+current_language = "en"
+
 
 def prepare_args(args, prefix):
     """
-- 
GitLab