Skip to content
Snippets Groups Projects
Commit 669bdbf9 authored by Andres Marafioti
Browse files

catch a few exceptions from melo

parent 3fff1d1d
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,6 @@ import torch
from baseHandler import BaseHandler
from rich.console import Console
import logging
from shared_variables import current_language
logger = logging.getLogger(__name__)
console = Console()
......@@ -109,10 +108,8 @@ class WhisperSTTHandler(BaseHandler):
)[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
print("WHISPER curr lang", language_id)
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
console.print(f"[red]Language ID Whisper: {language_id}")
logger.debug(f"Language ID Whisper: {language_id}")
yield (pred_text, language_id)
......@@ -5,7 +5,6 @@ import librosa
import numpy as np
from rich.console import Console
import torch
from shared_variables import current_language
logger = logging.getLogger(__name__)
......@@ -42,9 +41,15 @@ class MeloTTSHandler(BaseHandler):
):
self.should_listen = should_listen
self.device = device
self.language = "<|" + language + "|>" # 'Tokenize' the language code to do less operations
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]]
self.language = (
"<|" + language + "|>"
) # 'Tokenize' the language code to do less operations
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]
]
self.blocksize = blocksize
self.warmup()
......@@ -56,18 +61,24 @@ class MeloTTSHandler(BaseHandler):
language_id = None
if isinstance(llm_sentence, tuple):
print("llm sentence is tuple!")
llm_sentence, language_id = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
if language_id is not None and self.language != language_id:
try:
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id], device=self.device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]]
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],
device=self.device,
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]
]
self.language = language_id
except KeyError:
console.print(f"[red]Language {language_id} not supported by Melo. Using {self.language} instead.")
console.print(
f"[red]Language {language_id} not supported by Melo. Using {self.language} instead."
)
if self.device == "mps":
import time
......@@ -79,7 +90,13 @@ class MeloTTSHandler(BaseHandler):
time.time() - start
) # Removing this line makes it fail more often. I'm looking into it.
audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
try:
audio_chunk = self.model.tts_to_file(
llm_sentence, self.speaker_id, quiet=True
)
except (AssertionError, RuntimeError) as e:
logger.error(f"Error in MeloTTSHandler: {e}")
audio_chunk = np.array([])
if len(audio_chunk) == 0:
self.should_listen.set()
return
......
......@@ -49,7 +49,6 @@ console = Console()
logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
def prepare_args(args, prefix):
"""
Rename arguments by removing the prefix and prepares the gen_kwargs.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment