From 77894a7a5b7cd2e51b87b2a945688c020c15d566 Mon Sep 17 00:00:00 2001
From: andimarafioti <andimarafioti@gmail.com>
Date: Fri, 30 Aug 2024 09:38:44 +0000
Subject: [PATCH] working

---
 LLM/language_model.py      | 12 +++++++++++-
 STT/whisper_stt_handler.py | 29 ++++++++++++++++++++++++++---
 shared_variables.py        |  1 -
 3 files changed, 37 insertions(+), 5 deletions(-)
 delete mode 100644 shared_variables.py

diff --git a/LLM/language_model.py b/LLM/language_model.py
index 8acfdc8..6b48017 100644
--- a/LLM/language_model.py
+++ b/LLM/language_model.py
@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "<|en|>": "english",
+    "<|fr|>": "french",
+    "<|es|>": "spanish",
+    "<|zh|>": "chinese",
+    "<|ja|>": "japanese",
+    "<|ko|>": "korean",
+}
+
 class LanguageModelHandler(BaseHandler):
     """
     Handles the language model part.
     """
@@ -106,6 +115,7 @@ class LanguageModelHandler(BaseHandler):
         language_id = None
         if isinstance(prompt, tuple):
             prompt, language_id = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
         thread = Thread(
@@ -125,7 +135,7 @@ class LanguageModelHandler(BaseHandler):
                 printable_text += new_text
                 sentences = sent_tokenize(printable_text)
                 if len(sentences) > 1:
-                    yield (sentences[0])
+                    yield (sentences[0], language_id)
                     printable_text = new_text
 
         self.chat.append({"role": "assistant", "content": generated_text})
diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 38dc591..e3c99bd 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -1,10 +1,10 @@
 from time import perf_counter
 from transformers import (
-    AutoModelForSpeechSeq2Seq,
     AutoProcessor,
+    AutoModelForSpeechSeq2Seq
 )
 import torch
-
+from copy import copy
 from baseHandler import BaseHandler
 from rich.console import Console
 import logging
@@ -12,6 +12,15 @@ import logging
 logger = logging.getLogger(__name__)
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "<|en|>",
+    "<|fr|>",
+    "<|es|>",
+    "<|zh|>",
+    "<|ja|>",
+    "<|ko|>",
+]
+
 
 class WhisperSTTHandler(BaseHandler):
     """
@@ -24,13 +33,16 @@ class WhisperSTTHandler(BaseHandler):
         device="cuda",
         torch_dtype="float16",
         compile_mode=None,
+        language=None,
         gen_kwargs={},
     ):
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
         self.compile_mode = compile_mode
         self.gen_kwargs = gen_kwargs
-        del self.gen_kwargs["language"]
+        self.last_language = language
+        if self.last_language is not None:
+            self.gen_kwargs["language"] = self.last_language
 
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -103,6 +115,17 @@ class WhisperSTTHandler(BaseHandler):
         input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
+        language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
+
+        if language_id not in SUPPORTED_LANGUAGES:  # reprocess with the last language
+            logger.warning("Whisper detected unsupported language: %s", language_id)
+            gen_kwargs = copy(self.gen_kwargs)
+            gen_kwargs['language'] = self.last_language
+            language_id = self.last_language
+            pred_ids = self.model.generate(input_features, **gen_kwargs)
+        else:
+            self.last_language = language_id
+
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
 
diff --git a/shared_variables.py b/shared_variables.py
deleted file mode 100644
index 649f0f3..0000000
--- a/shared_variables.py
+++ /dev/null
@@ -1 +0,0 @@
-current_language = "en"
--
GitLab