Skip to content
Snippets Groups Projects
Commit 65f779de authored by Andres Marafioti
Browse files

review from eustache

parent 65bef760
No related branches found
No related tags found
No related merge requests found
......@@ -19,12 +19,12 @@ console = Console()
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"<|en|>": "english",
"<|fr|>": "french",
"<|es|>": "spanish",
"<|zh|>": "chinese",
"<|ja|>": "japanese",
"<|ko|>": "korean",
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}
class LanguageModelHandler(BaseHandler):
......@@ -112,10 +112,10 @@ class LanguageModelHandler(BaseHandler):
def process(self, prompt):
logger.debug("infering language model...")
language_id = None
language_code = None
if isinstance(prompt, tuple):
prompt, language_id = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
......@@ -135,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0], language_id)
yield (sentences[0], language_code)
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield (printable_text, language_id)
yield (printable_text, language_code)
......@@ -13,12 +13,12 @@ logger = logging.getLogger(__name__)
console = Console()
SUPPORTED_LANGUAGES = [
"<|en|>",
"<|fr|>",
"<|es|>",
"<|zh|>",
"<|ja|>",
"<|ko|>",
"en",
"fr",
"es",
"zh",
"ja",
"ko",
]
......@@ -117,24 +117,24 @@ class WhisperSTTHandler(BaseHandler):
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
if language_id not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_id)
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_code)
gen_kwargs = copy(self.gen_kwargs)
gen_kwargs['language'] = self.last_language
language_id = self.last_language
language_code = self.last_language
pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
self.last_language = language_id
self.last_language = language_code
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language ID Whisper: {language_id}")
logger.debug(f"Language Code Whisper: {language_code}")
yield (pred_text, language_id)
yield (pred_text, language_code)
......@@ -11,21 +11,21 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"<|en|>": "EN_NEWEST",
"<|fr|>": "FR",
"<|es|>": "ES",
"<|zh|>": "ZH",
"<|ja|>": "JP",
"<|ko|>": "KR",
"en": "EN_NEWEST",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"<|en|>": "EN-Newest",
"<|fr|>": "FR",
"<|es|>": "ES",
"<|zh|>": "ZH",
"<|ja|>": "JP",
"<|ko|>": "KR",
"en": "EN-Newest",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
......@@ -41,14 +41,12 @@ class MeloTTSHandler(BaseHandler):
):
self.should_listen = should_listen
self.device = device
self.language = (
"<|" + language + "|>"
) # 'Tokenize' the language code to do less operations
self.language = language
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]
WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
]
self.blocksize = blocksize
self.warmup()
......@@ -58,26 +56,26 @@ class MeloTTSHandler(BaseHandler):
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
def process(self, llm_sentence):
language_id = None
language_code = None
if isinstance(llm_sentence, tuple):
llm_sentence, language_id = llm_sentence
llm_sentence, language_code = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
if language_id is not None and self.language != language_id:
if language_code is not None and self.language != language_code:
try:
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
device=self.device,
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
]
self.language = language_id
self.language = language_code
except KeyError:
console.print(
f"[red]Language {language_id} not supported by Melo. Using {self.language} instead."
f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
)
if self.device == "mps":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment