Skip to content
Snippets Groups Projects
Commit 129cd11b authored by Andres Marafioti's avatar Andres Marafioti
Browse files

Fixes to bugs from original PR

parent 1d3b5bfc
No related branches found
No related tags found
No related merge requests found
......@@ -99,13 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
output += t
curr_output += t
if curr_output.endswith((".", "?", "!", "<|end|>")):
yield curr_output.replace("<|end|>", "")
yield (curr_output.replace("<|end|>", ""), language_code)
curr_output = ""
generated_text = output.replace("<|end|>", "")
printable_text = generated_text
torch.mps.empty_cache()
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield (printable_text, language_code)
\ No newline at end of file
self.chat.append({"role": "assistant", "content": generated_text})
\ No newline at end of file
......@@ -39,11 +39,8 @@ class LightningWhisperSTTHandler(BaseHandler):
model_name = model_name.split("/")[-1]
self.device = device
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
if language == 'auto':
language = None
self.start_language = language
self.last_language = language
if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language
self.warmup()
......@@ -63,25 +60,24 @@ class LightningWhisperSTTHandler(BaseHandler):
global pipeline_start
pipeline_start = perf_counter()
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
language_code = self.model.transcribe(spoken_prompt)["language"]
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_code)
gen_kwargs = copy(self.gen_kwargs) #####
gen_kwargs['language'] = self.last_language
language_code = self.last_language
# pred_ids = self.model.generate(input_features, **gen_kwargs)
if self.start_language != 'auto':
transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
else:
self.last_language = language_code
pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
transcription_dict = self.model.transcribe(spoken_prompt)
language_code = transcription_dict["language"]
if language_code not in SUPPORTED_LANGUAGES:
logger.warning(f"Whisper detected unsupported language: {language_code}")
if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language
transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
else:
transcription_dict = {"text": "", "language": "en"}
else:
self.last_language = language_code
pred_text = transcription_dict["text"].strip()
language_code = transcription_dict["language"]
torch.mps.empty_cache()
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
language_code = self.model.transcribe(spoken_prompt)["language"]
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")
......
......@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN",
"en": "EN_NEWEST",
"fr": "FR",
"es": "ES",
"zh": "ZH",
......@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-BR",
"en": "EN-Newest",
"fr": "FR",
"es": "ES",
"zh": "ZH",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment