Skip to content
Snippets Groups Projects
Unverified Commit 32ae906c authored by eustlb's avatar eustlb Committed by GitHub
Browse files

Merge pull request #112 from eustlb/improve-auto-language

Improve auto language
parents d5e46072 55ba279a
No related branches found
No related tags found
No related merge requests found
...@@ -115,7 +115,9 @@ class LanguageModelHandler(BaseHandler): ...@@ -115,7 +115,9 @@ class LanguageModelHandler(BaseHandler):
language_code = None language_code = None
if isinstance(prompt, tuple): if isinstance(prompt, tuple):
prompt, language_code = prompt prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt if language_code[-5:] == "-auto":
language_code = language_code[:-5]
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt}) self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread( thread = Thread(
......
...@@ -73,7 +73,9 @@ class MLXLanguageModelHandler(BaseHandler): ...@@ -73,7 +73,9 @@ class MLXLanguageModelHandler(BaseHandler):
if isinstance(prompt, tuple): if isinstance(prompt, tuple):
prompt, language_code = prompt prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt if language_code[-5:] == "-auto":
language_code = language_code[:-5]
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt}) self.chat.append({"role": self.user_role, "content": prompt})
......
from openai import OpenAI
from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging import logging
import time import time
from nltk import sent_tokenize
from rich.console import Console
from openai import OpenAI
from baseHandler import BaseHandler
from LLM.chat import Chat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
console = Console() console = Console()
from nltk import sent_tokenize
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}
class OpenApiModelHandler(BaseHandler): class OpenApiModelHandler(BaseHandler):
""" """
...@@ -61,7 +73,10 @@ class OpenApiModelHandler(BaseHandler): ...@@ -61,7 +73,10 @@ class OpenApiModelHandler(BaseHandler):
language_code = None language_code = None
if isinstance(prompt, tuple): if isinstance(prompt, tuple):
prompt, language_code = prompt prompt, language_code = prompt
if language_code[-5:] == "-auto":
language_code = language_code[:-5]
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=[ messages=[
......
...@@ -82,4 +82,7 @@ class LightningWhisperSTTHandler(BaseHandler): ...@@ -82,4 +82,7 @@ class LightningWhisperSTTHandler(BaseHandler):
console.print(f"[yellow]USER: {pred_text}") console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}") logger.debug(f"Language Code Whisper: {language_code}")
if self.start_language == "auto":
language_code += "-auto"
yield (pred_text, language_code) yield (pred_text, language_code)
...@@ -40,9 +40,8 @@ class WhisperSTTHandler(BaseHandler): ...@@ -40,9 +40,8 @@ class WhisperSTTHandler(BaseHandler):
self.torch_dtype = getattr(torch, torch_dtype) self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs self.gen_kwargs = gen_kwargs
if language == 'auto': self.start_language = language
language = None self.last_language = language if language != "auto" else None
self.last_language = language
if self.last_language is not None: if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language self.gen_kwargs["language"] = self.last_language
...@@ -137,4 +136,7 @@ class WhisperSTTHandler(BaseHandler): ...@@ -137,4 +136,7 @@ class WhisperSTTHandler(BaseHandler):
console.print(f"[yellow]USER: {pred_text}") console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}") logger.debug(f"Language Code Whisper: {language_code}")
if self.start_language == "auto":
language_code += "-auto"
yield (pred_text, language_code) yield (pred_text, language_code)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment