From 65f779de83a374830ee37d9750a5b3ddecfdca95 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Wed, 4 Sep 2024 13:39:58 +0200
Subject: [PATCH] review from eustache

---
 LLM/language_model.py      | 22 +++++++++++-----------
 STT/whisper_stt_handler.py | 28 ++++++++++++++--------------
 TTS/melo_handler.py        | 44 +++++++++++++++++++-----------------------
 3 files changed, 46 insertions(+), 48 deletions(-)

diff --git a/LLM/language_model.py b/LLM/language_model.py
index 6b48017..1a95762 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -19,12 +19,12 @@ console = Console()
 
 
 WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
-    "<|en|>": "english",
-    "<|fr|>": "french",
-    "<|es|>": "spanish",
-    "<|zh|>": "chinese",
-    "<|ja|>": "japanese",
-    "<|ko|>": "korean",
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
 }
 
 class LanguageModelHandler(BaseHandler):
@@ -112,10 +112,10 @@ class LanguageModelHandler(BaseHandler):
 
     def process(self, prompt):
         logger.debug("infering language model...")
-        language_id = None
+        language_code = None
         if isinstance(prompt, tuple):
-            prompt, language_id = prompt
-            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
         thread = Thread(
@@ -135,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
             printable_text += new_text
             sentences = sent_tokenize(printable_text)
             if len(sentences) > 1:
-                yield (sentences[0], language_id)
+                yield (sentences[0], language_code)
                 printable_text = new_text
 
         self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
-        yield (printable_text, language_id)
+        yield (printable_text, language_code)
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index d669e34..30d9307 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -13,12 +13,12 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 SUPPORTED_LANGUAGES = [
-    "<|en|>",
-    "<|fr|>",
-    "<|es|>",
-    "<|zh|>",
-    "<|ja|>",
-    "<|ko|>",
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
 ]
 
 
@@ -117,24 +117,24 @@ class WhisperSTTHandler(BaseHandler):
 
         input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
-        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
 
-        if language_id not in SUPPORTED_LANGUAGES: # reprocess with the last language
-            logger.warning("Whisper detected unsupported language:", language_id)
+        if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
+            logger.warning("Whisper detected unsupported language:", language_code)
             gen_kwargs = copy(self.gen_kwargs)
             gen_kwargs['language'] = self.last_language
-            language_id = self.last_language
+            language_code = self.last_language
             pred_ids = self.model.generate(input_features, **gen_kwargs)
         else:
-            self.last_language = language_id
+            self.last_language = language_code
 
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
-        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
-        logger.debug(f"Language ID Whisper: {language_id}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield (pred_text, language_id)
+        yield (pred_text, language_code)
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 64f371d..b1b2226 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -11,21 +11,21 @@ logger = logging.getLogger(__name__)
 console = Console()
 
 WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
-    "<|en|>": "EN_NEWEST",
-    "<|fr|>": "FR",
-    "<|es|>": "ES",
-    "<|zh|>": "ZH",
-    "<|ja|>": "JP",
-    "<|ko|>": "KR",
+    "en": "EN_NEWEST",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
 }
 
 WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
-    "<|en|>": "EN-Newest",
-    "<|fr|>": "FR",
-    "<|es|>": "ES",
-    "<|zh|>": "ZH",
-    "<|ja|>": "JP",
-    "<|ko|>": "KR",
+    "en": "EN-Newest",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
 }
 
 
@@ -41,14 +41,12 @@ class MeloTTSHandler(BaseHandler):
     ):
         self.should_listen = should_listen
         self.device = device
-        self.language = (
-            "<|" + language + "|>"
-        ) # 'Tokenize' the language code to do less operations
+        self.language = language
         self.model = TTS(
             language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
         )
         self.speaker_id = self.model.hps.data.spk2id[
-            WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]
+            WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
         ]
         self.blocksize = blocksize
         self.warmup()
@@ -58,26 +56,26 @@ class MeloTTSHandler(BaseHandler):
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
     def process(self, llm_sentence):
-        language_id = None
+        language_code = None
 
         if isinstance(llm_sentence, tuple):
-            llm_sentence, language_id = llm_sentence
+            llm_sentence, language_code = llm_sentence
 
         console.print(f"[green]ASSISTANT: {llm_sentence}")
 
-        if language_id is not None and self.language != language_id:
+        if language_code is not None and self.language != language_code:
             try:
                 self.model = TTS(
-                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],
+                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
                     device=self.device,
                 )
                 self.speaker_id = self.model.hps.data.spk2id[
-                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]
+                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
                 ]
-                self.language = language_id
+                self.language = language_code
             except KeyError:
                 console.print(
-                    f"[red]Language {language_id} not supported by Melo. Using {self.language} instead."
+                    f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
                 )
 
         if self.device == "mps":
-- 
GitLab