diff --git a/LLM/language_model.py b/LLM/language_model.py index 1a957625b43193f936898ea68e4b48d810e86101..ddeb34b1e6895a6ffd77a9f734cb17ad50a1c3a0 100644 --- a/LLM/language_model.py +++ b/LLM/language_model.py @@ -115,7 +115,9 @@ class LanguageModelHandler(BaseHandler): language_code = None if isinstance(prompt, tuple): prompt, language_code = prompt - prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt + if language_code[-5:] == "-auto": + language_code = language_code[:-5] + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt self.chat.append({"role": self.user_role, "content": prompt}) thread = Thread( diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py index 82de10214299d4ebe68200256cec73f22974fd02..87812c53ed9686e4e8cb7c39d657965fb92a950d 100644 --- a/LLM/mlx_language_model.py +++ b/LLM/mlx_language_model.py @@ -73,7 +73,9 @@ class MLXLanguageModelHandler(BaseHandler): if isinstance(prompt, tuple): prompt, language_code = prompt - prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt + if language_code[-5:] == "-auto": + language_code = language_code[:-5] + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt self.chat.append({"role": self.user_role, "content": prompt}) diff --git a/LLM/openai_api_language_model.py b/LLM/openai_api_language_model.py index 4614d386b183c4192aeafbb857f2bfad15bf413f..dcbabe0a50879eba7b5a59857d168c11507362aa 100644 --- a/LLM/openai_api_language_model.py +++ b/LLM/openai_api_language_model.py @@ -1,13 +1,25 @@ -from openai import OpenAI -from LLM.chat import Chat -from baseHandler import BaseHandler -from rich.console import Console import logging import time + +from nltk import sent_tokenize +from rich.console import Console +from openai import OpenAI + +from baseHandler import BaseHandler +from LLM.chat import Chat + logger = logging.getLogger(__name__) console = Console() -from nltk import sent_tokenize + +WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { + "en": "english", + "fr": "french", + "es": "spanish", + "zh": "chinese", + "ja": "japanese", + "ko": "korean", +} class OpenApiModelHandler(BaseHandler): """ @@ -61,7 +73,10 @@ class OpenApiModelHandler(BaseHandler): language_code = None if isinstance(prompt, tuple): prompt, language_code = prompt - + if language_code[-5:] == "-auto": + language_code = language_code[:-5] + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt + response = self.client.chat.completions.create( model=self.model_name, messages=[ diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py index 6f9fbb217bb32b227096e4594f6e648fa879d37a..53b6b5a035ec9af062b8f21c562a7a9e6b7e3846 100644 --- a/STT/lightning_whisper_mlx_handler.py +++ b/STT/lightning_whisper_mlx_handler.py @@ -82,4 +82,7 @@ class LightningWhisperSTTHandler(BaseHandler): console.print(f"[yellow]USER: {pred_text}") logger.debug(f"Language Code Whisper: {language_code}") + if self.start_language == "auto": + language_code += "-auto" + yield (pred_text, language_code) diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index 06cf613b01156aadcf0ccee9533c0d7aff039930..09300879e1dea3f790e3db40349b1da2b9675888 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -40,9 +40,8 @@ class WhisperSTTHandler(BaseHandler): self.torch_dtype = getattr(torch, torch_dtype) self.compile_mode = compile_mode self.gen_kwargs = gen_kwargs - if language == 'auto': - language = None - self.last_language = language + self.start_language = language + self.last_language = language if language != "auto" else None if self.last_language is not None: self.gen_kwargs["language"] = self.last_language @@ -137,4 +136,7 @@ class WhisperSTTHandler(BaseHandler): console.print(f"[yellow]USER: {pred_text}") logger.debug(f"Language Code Whisper: {language_code}") + if self.start_language == "auto": + language_code += "-auto" + yield (pred_text, language_code)