diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 6aa165c7746ecc7a239125c4af1fe69c70ec8d63..f1f08c552308a1b42f99b7d015802d2eb55c6ce7 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -105,6 +105,12 @@ class WhisperSTTHandler(BaseHandler):
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
+        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+        for char in "<>|":
+            language_id = language_id.replace(char, "")  # remove special tokens
+
+        global current_language
+        current_language = language_id
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 376747bc8982332ac66f8c6d41814f8bb2ef69e8..24e366d446990630585f090a672013c35fdffc22 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -10,21 +10,40 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
+    "en": "EN_NEWEST",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
+WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
+    "en": "EN-Newest",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
 
 class MeloTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
         device="mps",
-        language="EN_NEWEST",
-        speaker_to_id="EN-Newest",
+        language="en",
+        speaker_to_id="en",
         gen_kwargs={},  # Unused
         blocksize=512,
     ):
         self.should_listen = should_listen
         self.device = device
-        self.model = TTS(language=language, device=device)
-        self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
+        self.language = language
+        self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device)
+        self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]]
         self.blocksize = blocksize
         self.warmup()
 
@@ -34,6 +53,11 @@ class MeloTTSHandler(BaseHandler):
 
     def process(self, llm_sentence):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
+        global current_language
+        if self.language != current_language:
+            self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
+            self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
+
         if self.device == "mps":
             import time
 
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 8da829834e85c856458a571bf3c7242500d8ae6b..d85ade8a3e5e5308f9d36104e58a973cb40ffa48 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -48,6 +48,8 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
 console = Console()
 logging.getLogger("numba").setLevel(logging.WARNING)  # quiet down numba logs
 
+current_language = "en"
+
 
 def prepare_args(args, prefix):
     """
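
Illustration (not part of the diff): a minimal, self-contained sketch of how the detected language would travel from Whisper's output to the Melo voice selection, assuming the same mapping tables added above. Whisper's second generated token (index 1 of pred_ids) is the language tag, e.g. "<|fr|>"; stripping the "<", ">", and "|" characters yields the short code that indexes the Melo language and speaker mappings. The helper name melo_voice_for and the hard-coded token strings are hypothetical stand-ins for the real tokenizer decode.

# Hypothetical standalone sketch; the dicts mirror the ones added in TTS/melo_handler.py.
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
    "en": "EN_NEWEST", "fr": "FR", "es": "ES", "zh": "ZH", "ja": "JP", "ko": "KR",
}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
    "en": "EN-Newest", "fr": "FR", "es": "ES", "zh": "ZH", "ja": "JP", "ko": "KR",
}

def melo_voice_for(decoded_language_token):
    # decoded_language_token stands in for tokenizer.decode(pred_ids[0, 1]), e.g. "<|fr|>"
    language_id = decoded_language_token
    for char in "<>|":
        language_id = language_id.replace(char, "")  # "<|fr|>" -> "fr"
    return (
        WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],  # Melo model language, e.g. "FR"
        WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id],   # Melo speaker key, e.g. "FR"
    )

print(melo_voice_for("<|fr|>"))  # ('FR', 'FR')
print(melo_voice_for("<|en|>"))  # ('EN_NEWEST', 'EN-Newest')

In the handlers themselves this code is communicated through the module-level current_language variable rather than a return value.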