From 901d9a1402d6244083e1fc8243b163f210381145 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Thu, 29 Aug 2024 10:37:44 +0200
Subject: [PATCH] handle language better

---
 LLM/language_model.py | 5 ++++-
 TTS/melo_handler.py   | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/LLM/language_model.py b/LLM/language_model.py
index e0d5829..cb24c89 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -101,8 +101,11 @@ class LanguageModelHandler(BaseHandler):
                 f"{self.__class__.__name__}:  warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
             )
 
-    def process(self, prompt, language_id=None):
+    def process(self, prompt):
         logger.debug("infering language model...")
+        language_id = None
+        if isinstance(prompt, tuple):
+            prompt, language_id = prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
         thread = Thread(
diff --git a/TTS/melo_handler.py b/TTS/melo_handler.py
index 6b5d831..45eef88 100644
--- a/TTS/melo_handler.py
+++ b/TTS/melo_handler.py
@@ -52,8 +52,11 @@ class MeloTTSHandler(BaseHandler):
         logger.info(f"Warming up {self.__class__.__name__}")
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
-    def process(self, llm_sentence, language_id=None):
+    def process(self, llm_sentence):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
+        language_id = None
+        if isinstance(prompt, tuple):
+            prompt, language_id = prompt
 
         if language_id is not None and self.language != language_id:
             self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
-- 
GitLab