Skip to content
Snippets Groups Projects
Commit cfd3065e authored by Andres Marafioti's avatar Andres Marafioti
Browse files

language fun

parent d2d33d10
No related branches found
No related tags found
No related merge requests found
......@@ -105,6 +105,12 @@ class WhisperSTTHandler(BaseHandler):
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
for char in "<>|":
language_id = language_id.replace(char, "") # remove special tokens
global current_language
current_language = language_id
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
......
......@@ -10,21 +10,40 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN_NEWEST",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-Newest",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
class MeloTTSHandler(BaseHandler):
def setup(
self,
should_listen,
device="mps",
language="EN_NEWEST",
speaker_to_id="EN-Newest",
language="en",
speaker_to_id="en",
gen_kwargs={}, # Unused
blocksize=512,
):
self.should_listen = should_listen
self.device = device
self.model = TTS(language=language, device=device)
self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
self.language = language
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language], device=device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]]
self.blocksize = blocksize
self.warmup()
......@@ -34,6 +53,11 @@ class MeloTTSHandler(BaseHandler):
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
global current_language
if self.language != current_language:
self.model = TTS(language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=self.device)
self.speaker_id = self.model.hps.data.spk2id[WHISPER_LANGUAGE_TO_MELO_SPEAKER[self.language]]
if self.device == "mps":
import time
......
......@@ -48,6 +48,8 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
console = Console()
logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
current_language = "en"
def prepare_args(args, prefix):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment