Skip to content
Snippets Groups Projects
Commit 55ba279a authored by Eustache Le Bihan's avatar Eustache Le Bihan
Browse files

fix whisper trfms handle auto language

parent c0eeceb9
Branches
Tags
No related merge requests found
...@@ -83,6 +83,6 @@ class LightningWhisperSTTHandler(BaseHandler): ...@@ -83,6 +83,6 @@ class LightningWhisperSTTHandler(BaseHandler):
logger.debug(f"Language Code Whisper: {language_code}") logger.debug(f"Language Code Whisper: {language_code}")
if self.start_language == "auto": if self.start_language == "auto":
language_code += "-auto" language_code += "-auto"
yield (pred_text, language_code) yield (pred_text, language_code)
...@@ -40,9 +40,8 @@ class WhisperSTTHandler(BaseHandler): ...@@ -40,9 +40,8 @@ class WhisperSTTHandler(BaseHandler):
self.torch_dtype = getattr(torch, torch_dtype) self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs self.gen_kwargs = gen_kwargs
if language == 'auto': self.start_language = language
language = None self.last_language = language if language != "auto" else None
self.last_language = language
if self.last_language is not None: if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language self.gen_kwargs["language"] = self.last_language
...@@ -137,7 +136,7 @@ class WhisperSTTHandler(BaseHandler): ...@@ -137,7 +136,7 @@ class WhisperSTTHandler(BaseHandler):
console.print(f"[yellow]USER: {pred_text}") console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}") logger.debug(f"Language Code Whisper: {language_code}")
if self.language is None: if self.start_language == "auto":
language_code += "-auto" language_code += "-auto"
yield (pred_text, language_code) yield (pred_text, language_code)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment