Skip to content
Snippets Groups Projects
Commit 29fedd07 authored by Elli0t's avatar Elli0t
Browse files

Update: Added multi-language support for macOS

parent d98e2527
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__) ...@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
console = Console() console = Console()
# Maps Whisper's ISO-639-1 language codes to the English language names
# used when prompting the LLM to reply in the detected language.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = dict(
    en="english",
    fr="french",
    es="spanish",
    zh="chinese",
    ja="japanese",
    ko="korean",
)
class MLXLanguageModelHandler(BaseHandler): class MLXLanguageModelHandler(BaseHandler):
""" """
...@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler): ...@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
def process(self, prompt): def process(self, prompt):
logger.debug("infering language model...") logger.debug("infering language model...")
language_code = None
if isinstance(prompt, tuple):
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt}) self.chat.append({"role": self.user_role, "content": prompt})
...@@ -89,6 +102,10 @@ class MLXLanguageModelHandler(BaseHandler): ...@@ -89,6 +102,10 @@ class MLXLanguageModelHandler(BaseHandler):
yield curr_output.replace("<|end|>", "") yield curr_output.replace("<|end|>", "")
curr_output = "" curr_output = ""
generated_text = output.replace("<|end|>", "") generated_text = output.replace("<|end|>", "")
printable_text = generated_text
torch.mps.empty_cache() torch.mps.empty_cache()
self.chat.append({"role": "assistant", "content": generated_text}) self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield (printable_text, language_code)
\ No newline at end of file
...@@ -4,12 +4,22 @@ from baseHandler import BaseHandler ...@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
from lightning_whisper_mlx import LightningWhisperMLX from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np import numpy as np
from rich.console import Console from rich.console import Console
from copy import copy
import torch import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
console = Console() console = Console()
# ISO-639-1 codes the pipeline accepts from Whisper's language detection;
# anything else triggers a fallback to the last known language.
SUPPORTED_LANGUAGES = "en fr es zh ja ko".split()
class LightningWhisperSTTHandler(BaseHandler): class LightningWhisperSTTHandler(BaseHandler):
""" """
...@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler): ...@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
def setup( def setup(
self, self,
model_name="distil-large-v3", model_name="distil-large-v3",
device="cuda", device="mps",
torch_dtype="float16", torch_dtype="float16",
compile_mode=None, compile_mode=None,
language=None, language=None,
...@@ -29,6 +39,12 @@ class LightningWhisperSTTHandler(BaseHandler): ...@@ -29,6 +39,12 @@ class LightningWhisperSTTHandler(BaseHandler):
model_name = model_name.split("/")[-1] model_name = model_name.split("/")[-1]
self.device = device self.device = device
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None) self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
if language == 'auto':
language = None
self.last_language = language
if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language
self.warmup() self.warmup()
def warmup(self): def warmup(self):
...@@ -47,10 +63,27 @@ class LightningWhisperSTTHandler(BaseHandler): ...@@ -47,10 +63,27 @@ class LightningWhisperSTTHandler(BaseHandler):
global pipeline_start global pipeline_start
pipeline_start = perf_counter() pipeline_start = perf_counter()
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
language_code = self.model.transcribe(spoken_prompt)["language"]
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_code)
gen_kwargs = copy(self.gen_kwargs) #####
gen_kwargs['language'] = self.last_language
language_code = self.last_language
# pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
self.last_language = language_code
pred_text = self.model.transcribe(spoken_prompt)["text"].strip() pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
torch.mps.empty_cache() torch.mps.empty_cache()
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
language_code = self.model.transcribe(spoken_prompt)["language"]
logger.debug("finished whisper inference") logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}") console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")
yield pred_text yield (pred_text, language_code)
...@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) ...@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
console = Console() console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN_NEWEST", "en": "EN",
"fr": "FR", "fr": "FR",
"es": "ES", "es": "ES",
"zh": "ZH", "zh": "ZH",
...@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { ...@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
} }
WHISPER_LANGUAGE_TO_MELO_SPEAKER = { WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-Newest", "en": "EN-BR",
"fr": "FR", "fr": "FR",
"es": "ES", "es": "ES",
"zh": "ZH", "zh": "ZH",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment