Skip to content
Snippets Groups Projects
Commit 77894a7a authored by andimarafioti's avatar andimarafioti Committed by Andres Marafioti
Browse files

working

parent 669bdbf9
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
console = Console()
# Maps a Whisper language token to the language name that is spliced into the
# "Please reply to my message in ..." instruction sent to the language model.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = dict(
    (
        ("<|en|>", "english"),
        ("<|fr|>", "french"),
        ("<|es|>", "spanish"),
        ("<|zh|>", "chinese"),
        ("<|ja|>", "japanese"),
        ("<|ko|>", "korean"),
    )
)
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
......@@ -106,6 +115,7 @@ class LanguageModelHandler(BaseHandler):
language_id = None
# A tuple prompt carries (text, language tag); a plain prompt is used as is.
if isinstance(prompt, tuple):
    prompt, language_id = prompt
    # The incoming tag may have no entry in the mapping (for example None when
    # the speech-to-text stage had no language to fall back on). Only add the
    # reply-language instruction when the tag is known, instead of raising
    # KeyError and losing the user's message.
    reply_language = WHISPER_LANGUAGE_TO_LLM_LANGUAGE.get(language_id)
    if reply_language is not None:
        prompt = f"Please reply to my message in {reply_language}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
......@@ -125,7 +135,7 @@ class LanguageModelHandler(BaseHandler):
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0])
yield (sentences[0], language_id)
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
......
from time import perf_counter
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForSpeechSeq2Seq
)
import torch
from copy import copy
from baseHandler import BaseHandler
from rich.console import Console
import logging
......@@ -12,6 +12,15 @@ import logging
logger = logging.getLogger(__name__)
console = Console()
# Whisper language tokens accepted as a detection result; any other token makes
# the handler reprocess the audio with the previously used language.
SUPPORTED_LANGUAGES = ["<|en|>", "<|fr|>", "<|es|>", "<|zh|>", "<|ja|>", "<|ko|>"]
class WhisperSTTHandler(BaseHandler):
"""
......@@ -24,13 +33,16 @@ class WhisperSTTHandler(BaseHandler):
device="cuda",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
self.device = device
# torch_dtype arrives as a string such as "float16"; resolve it to the torch dtype.
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
# Work on a private copy so neither the caller's dict nor the shared default
# argument is mutated by the language handling below.
self.gen_kwargs = dict(gen_kwargs)
# The `language` argument decides the generation language. Drop any value
# passed through gen_kwargs; pop() with a default tolerates a missing key,
# whereas `del` raised KeyError (always the case with the default `{}`).
self.gen_kwargs.pop("language", None)
# Remembered so later calls can fall back on it when detection is unusable.
self.last_language = language
if self.last_language is not None:
    self.gen_kwargs["language"] = self.last_language
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
......@@ -103,6 +115,17 @@ class WhisperSTTHandler(BaseHandler):
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
if language_id not in SUPPORTED_LANGUAGES:  # reprocess with the last language
    # logging interpolates its extra arguments into the message with %-style
    # formatting, so the message needs a placeholder; without one the record
    # fails to format and the warning is lost.
    logger.warning("Whisper detected unsupported language: %s", language_id)
    # Copy so the forced language only applies to this one retry.
    gen_kwargs = copy(self.gen_kwargs)
    gen_kwargs["language"] = self.last_language
    # NOTE(review): self.last_language may still be None if no language was
    # configured and none has been detected yet; language_id is then None for
    # whatever consumes it downstream - confirm that case is handled.
    language_id = self.last_language
    pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
    # Remember the detected language as the fallback for later utterances.
    self.last_language = language_id
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
......
current_language = "en"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment