Skip to content
Snippets Groups Projects
Commit 2cb9464b authored by Andres Marafioti
Browse files

try to pass the language_id in the queue

parent 0555d4dc
No related branches found
No related tags found
No related merge requests found
......@@ -101,7 +101,7 @@ class LanguageModelHandler(BaseHandler):
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, prompt):
def process(self, prompt, language_id=None):
logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt})
......@@ -128,4 +128,4 @@ class LanguageModelHandler(BaseHandler):
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text
yield (printable_text, language_id)
......@@ -108,13 +108,11 @@ class WhisperSTTHandler(BaseHandler):
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
for char in "<>|":
language_id = language_id.replace(char, "") # remove special tokens
global current_language
current_language = language_id
print("WHISPER curr lang", language_id)
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
console.print(f"[red]Language ID Whisper: {language_id}")
yield pred_text
yield (pred_text, language_id)
......@@ -12,21 +12,21 @@ logger = logging.getLogger(__name__)
console = Console()
# Lookup table from a language identifier to the language name handed to
# `TTS(language=...)` when a MeloTTS model is (re)built.
# NOTE(review): both bare codes ("en") and token-style codes ("<|en|>") are
# listed. This text looks like a rendered diff, so one of the two key sets may
# be the pre-change version -- confirm which form callers actually pass.
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
# Bare language codes.
"en": "EN_NEWEST",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
# Same languages, keyed by the "<|xx|>" token form.
"<|en|>": "EN_NEWEST",
"<|fr|>": "FR",
"<|es|>": "ES",
"<|zh|>": "ZH",
"<|ja|>": "JP",
"<|ko|>": "KR",
}
# Lookup table from a language identifier to the speaker name used as a key
# into `model.hps.data.spk2id` to obtain the speaker id.
# NOTE(review): both bare codes ("en") and token-style codes ("<|en|>") are
# listed. This text looks like a rendered diff, so one of the two key sets may
# be the pre-change version -- confirm which form callers actually pass.
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
# Bare language codes.
"en": "EN-Newest",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
# Same languages, keyed by the "<|xx|>" token form.
"<|en|>": "EN-Newest",
"<|fr|>": "FR",
"<|es|>": "ES",
"<|zh|>": "ZH",
"<|ja|>": "JP",
"<|ko|>": "KR",
}
......@@ -42,7 +42,7 @@ class MeloTTSHandler(BaseHandler):
):
self.should_listen = should_listen
self.device = device
self.language = language
self.language = "<|" + language + "|>" # 'Tokenize' the language code to do less operations
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]]
self.blocksize = blocksize
......@@ -52,9 +52,10 @@ class MeloTTSHandler(BaseHandler):
logger.info(f"Warming up {self.__class__.__name__}")
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
def process(self, llm_sentence):
def process(self, llm_sentence, language_id=None):
console.print(f"[green]ASSISTANT: {llm_sentence}")
if self.language != current_language:
if language_id is not None and self.language != language_id:
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment