Skip to content
Snippets Groups Projects
Commit 2cb9464b authored by Andres Marafioti's avatar Andres Marafioti
Browse files

try to pass the language_id in the queue

parent 0555d4dc
No related branches found
No related tags found
No related merge requests found
...@@ -101,7 +101,7 @@ class LanguageModelHandler(BaseHandler): ...@@ -101,7 +101,7 @@ class LanguageModelHandler(BaseHandler):
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
) )
def process(self, prompt): def process(self, prompt, language_id=None):
logger.debug("infering language model...") logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt}) self.chat.append({"role": self.user_role, "content": prompt})
...@@ -128,4 +128,4 @@ class LanguageModelHandler(BaseHandler): ...@@ -128,4 +128,4 @@ class LanguageModelHandler(BaseHandler):
self.chat.append({"role": "assistant", "content": generated_text}) self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence # don't forget last sentence
yield printable_text yield (printable_text, language_id)
...@@ -108,13 +108,11 @@ class WhisperSTTHandler(BaseHandler): ...@@ -108,13 +108,11 @@ class WhisperSTTHandler(BaseHandler):
pred_ids, skip_special_tokens=True, decode_with_timestamps=False pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0] )[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1]) language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
for char in "<>|":
language_id = language_id.replace(char, "") # remove special tokens
global current_language print("WHISPER curr lang", language_id)
current_language = language_id
logger.debug("finished whisper inference") logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}") console.print(f"[yellow]USER: {pred_text}")
console.print(f"[red]Language ID Whisper: {language_id}")
yield pred_text yield (pred_text, language_id)
...@@ -12,21 +12,21 @@ logger = logging.getLogger(__name__) ...@@ -12,21 +12,21 @@ logger = logging.getLogger(__name__)
console = Console() console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN_NEWEST", "<|en|>": "EN_NEWEST",
"fr": "FR", "<|fr|>": "FR",
"es": "ES", "<|es|>": "ES",
"zh": "ZH", "<|zh|>": "ZH",
"ja": "JP", "<|ja|>": "JP",
"ko": "KR", "<|ko|>": "KR",
} }
WHISPER_LANGUAGE_TO_MELO_SPEAKER = { WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-Newest", "<|en|>": "EN-Newest",
"fr": "FR", "<|fr|>": "FR",
"es": "ES", "<|es|>": "ES",
"zh": "ZH", "<|zh|>": "ZH",
"ja": "JP", "<|ja|>": "JP",
"ko": "KR", "<|ko|>": "KR",
} }
...@@ -42,7 +42,7 @@ class MeloTTSHandler(BaseHandler): ...@@ -42,7 +42,7 @@ class MeloTTSHandler(BaseHandler):
): ):
self.should_listen = should_listen self.should_listen = should_listen
self.device = device self.device = device
self.language = language self.language = "<|" + language + "|>" # 'Tokenize' the language code to do less operations
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device) self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]] self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]]
self.blocksize = blocksize self.blocksize = blocksize
...@@ -52,9 +52,10 @@ class MeloTTSHandler(BaseHandler): ...@@ -52,9 +52,10 @@ class MeloTTSHandler(BaseHandler):
logger.info(f"Warming up {self.__class__.__name__}") logger.info(f"Warming up {self.__class__.__name__}")
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True) _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
def process(self, llm_sentence): def process(self, llm_sentence, language_id=None):
console.print(f"[green]ASSISTANT: {llm_sentence}") console.print(f"[green]ASSISTANT: {llm_sentence}")
if self.language != current_language:
if language_id is not None and self.language != language_id:
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device) self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]] self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment