Skip to content
Snippets Groups Projects
Commit 669bdbf9 authored by Andres Marafioti's avatar Andres Marafioti
Browse files

catch a few exceptions from melo

parent 3fff1d1d
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from baseHandler import BaseHandler from baseHandler import BaseHandler
from rich.console import Console from rich.console import Console
import logging import logging
from shared_variables import current_language
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
console = Console() console = Console()
...@@ -109,10 +108,8 @@ class WhisperSTTHandler(BaseHandler): ...@@ -109,10 +108,8 @@ class WhisperSTTHandler(BaseHandler):
)[0] )[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1]) language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
print("WHISPER curr lang", language_id)
logger.debug("finished whisper inference") logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}") console.print(f"[yellow]USER: {pred_text}")
console.print(f"[red]Language ID Whisper: {language_id}") logger.debug(f"Language ID Whisper: {language_id}")
yield (pred_text, language_id) yield (pred_text, language_id)
...@@ -5,7 +5,6 @@ import librosa ...@@ -5,7 +5,6 @@ import librosa
import numpy as np import numpy as np
from rich.console import Console from rich.console import Console
import torch import torch
from shared_variables import current_language
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -42,9 +41,15 @@ class MeloTTSHandler(BaseHandler): ...@@ -42,9 +41,15 @@ class MeloTTSHandler(BaseHandler):
): ):
self.should_listen = should_listen self.should_listen = should_listen
self.device = device self.device = device
self.language = "<|" + language + "|>" # 'Tokenize' the language code to do less operations self.language = (
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device) "<|" + language + "|>"
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]] ) # 'Tokenize' the language code to do less operations
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]
]
self.blocksize = blocksize self.blocksize = blocksize
self.warmup() self.warmup()
...@@ -56,18 +61,24 @@ class MeloTTSHandler(BaseHandler): ...@@ -56,18 +61,24 @@ class MeloTTSHandler(BaseHandler):
language_id = None language_id = None
if isinstance(llm_sentence, tuple): if isinstance(llm_sentence, tuple):
print("llm sentence is tuple!")
llm_sentence, language_id = llm_sentence llm_sentence, language_id = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}") console.print(f"[green]ASSISTANT: {llm_sentence}")
if language_id is not None and self.language != language_id: if language_id is not None and self.language != language_id:
try: try:
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id], device=self.device) self.model = TTS(
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]] language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],
device=self.device,
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]
]
self.language = language_id self.language = language_id
except KeyError: except KeyError:
console.print(f"[red]Language {language_id} not supported by Melo. Using {self.language} instead.") console.print(
f"[red]Language {language_id} not supported by Melo. Using {self.language} instead."
)
if self.device == "mps": if self.device == "mps":
import time import time
...@@ -79,7 +90,13 @@ class MeloTTSHandler(BaseHandler): ...@@ -79,7 +90,13 @@ class MeloTTSHandler(BaseHandler):
time.time() - start time.time() - start
) # Removing this line makes it fail more often. I'm looking into it. ) # Removing this line makes it fail more often. I'm looking into it.
audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True) try:
audio_chunk = self.model.tts_to_file(
llm_sentence, self.speaker_id, quiet=True
)
except (AssertionError, RuntimeError) as e:
logger.error(f"Error in MeloTTSHandler: {e}")
audio_chunk = np.array([])
if len(audio_chunk) == 0: if len(audio_chunk) == 0:
self.should_listen.set() self.should_listen.set()
return return
......
...@@ -49,7 +49,6 @@ console = Console() ...@@ -49,7 +49,6 @@ console = Console()
logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
def prepare_args(args, prefix): def prepare_args(args, prefix):
""" """
Rename arguments by removing the prefix and prepares the gen_kwargs. Rename arguments by removing the prefix and prepares the gen_kwargs.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment