diff --git a/LLM/language_model.py b/LLM/language_model.py index 8acfdc84f919cea75fd363a0e79b991845205baa..6b480176c991a30fb51ef9ec72784395ffe0d45c 100644 --- a/LLM/language_model.py +++ b/LLM/language_model.py @@ -18,6 +18,15 @@ logger = logging.getLogger(__name__) console = Console() +WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { + "<|en|>": "english", + "<|fr|>": "french", + "<|es|>": "spanish", + "<|zh|>": "chinese", + "<|ja|>": "japanese", + "<|ko|>": "korean", +} + class LanguageModelHandler(BaseHandler): """ Handles the language model part. @@ -106,6 +115,7 @@ class LanguageModelHandler(BaseHandler): language_id = None if isinstance(prompt, tuple): prompt, language_id = prompt + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt self.chat.append({"role": self.user_role, "content": prompt}) thread = Thread( @@ -125,7 +135,7 @@ class LanguageModelHandler(BaseHandler): printable_text += new_text sentences = sent_tokenize(printable_text) if len(sentences) > 1: - yield (sentences[0]) + yield (sentences[0], language_id) printable_text = new_text self.chat.append({"role": "assistant", "content": generated_text}) diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index 38dc5910770a15393d93e7ca03a52a83c21d74f3..e3c99bd2310c97d2eded7ba4bce5427ebe3a3e4a 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -1,10 +1,10 @@ from time import perf_counter from transformers import ( - AutoModelForSpeechSeq2Seq, AutoProcessor, + AutoModelForSpeechSeq2Seq ) import torch - +from copy import copy from baseHandler import BaseHandler from rich.console import Console import logging @@ -12,6 +12,15 @@ import logging logger = logging.getLogger(__name__) console = Console() +SUPPORTED_LANGUAGES = [ + "<|en|>", + "<|fr|>", + "<|es|>", + "<|zh|>", + "<|ja|>", + "<|ko|>", +] + class WhisperSTTHandler(BaseHandler): """ @@ -24,13 +33,16 @@ class WhisperSTTHandler(BaseHandler): device="cuda", 
torch_dtype="float16", compile_mode=None, + language=None, gen_kwargs={}, ): self.device = device self.torch_dtype = getattr(torch, torch_dtype) self.compile_mode = compile_mode self.gen_kwargs = gen_kwargs - del self.gen_kwargs["language"] + self.last_language = language + if self.last_language is not None: + self.gen_kwargs["language"] = self.last_language self.processor = AutoProcessor.from_pretrained(model_name) self.model = AutoModelForSpeechSeq2Seq.from_pretrained( @@ -103,6 +115,17 @@ class WhisperSTTHandler(BaseHandler): input_features = self.prepare_model_inputs(spoken_prompt) pred_ids = self.model.generate(input_features, **self.gen_kwargs) + language_id = self.processor.tokenizer.decode(pred_ids[0, 1]) + + if language_id not in SUPPORTED_LANGUAGES: # reprocess with the last language + logger.warning("Whisper detected unsupported language: %s", language_id) + gen_kwargs = copy(self.gen_kwargs) + gen_kwargs['language'] = self.last_language + language_id = self.last_language + pred_ids = self.model.generate(input_features, **gen_kwargs) + else: + self.last_language = language_id + pred_text = self.processor.batch_decode( pred_ids, skip_special_tokens=True, decode_with_timestamps=False )[0] diff --git a/shared_variables.py b/shared_variables.py deleted file mode 100644 index 649f0f30f718c7210580681832f49643872caab5..0000000000000000000000000000000000000000 --- a/shared_variables.py +++ /dev/null @@ -1 +0,0 @@ -current_language = "en"