diff --git a/LLM/language_model.py b/LLM/language_model.py index e0d5829a138d87135690b8aca0d5a5db4093e7c3..cb24c895d04e912d9ca06092a082f29b9d01b2b1 100644 --- a/LLM/language_model.py +++ b/LLM/language_model.py @@ -101,8 +101,11 @@ class LanguageModelHandler(BaseHandler): f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" ) - def process(self, prompt, language_id=None): + def process(self, prompt): logger.debug("infering language model...") + language_id = None + if isinstance(prompt, tuple): + prompt, language_id = prompt self.chat.append({"role": self.user_role, "content": prompt}) thread = Thread( diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py index 6b5d83147e16e6c1fdf87c3d0f45daffc09f1aa0..45eef88f2297b946d6602961b13ec6ea1e3987c1 100644 --- a/TTS/melo_handler.py +++ b/TTS/melo_handler.py @@ -52,8 +52,11 @@ class MeloTTSHandler(BaseHandler): logger.info(f"Warming up {self.__class__.__name__}") _ = self.model.tts_to_file("text", self.speaker_id, quiet=True) - def process(self, llm_sentence, language_id=None): + def process(self, llm_sentence): console.print(f"[green]ASSISTANT: {llm_sentence}") + language_id = None + if isinstance(prompt, tuple): + prompt, language_id = prompt if language_id is not None and self.language != language_id: self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)