Skip to content
Snippets Groups Projects
Unverified Commit b915e58b authored by Andrés Marafioti's avatar Andrés Marafioti Committed by GitHub
Browse files

Merge pull request #60 from huggingface/multi-language

Add support for multiple languages
parents d2d33d10 4e6055f1
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
console = Console()
# Maps a Whisper language code to the language name that is spelled out
# in the "Please reply to my message in ..." instruction sent to the LLM.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
......@@ -69,7 +78,7 @@ class LanguageModelHandler(BaseHandler):
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
dummy_input_text = "Write me a poem about Machine Learning."
dummy_input_text = "Repeat the word 'home'."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["min_new_tokens"],
......@@ -103,6 +112,10 @@ class LanguageModelHandler(BaseHandler):
def process(self, prompt):
    """
    Run the language model on one user prompt.

    Args:
        prompt: Either the prompt text, or a ``(text, language_code)`` tuple.
            When the code is a key of ``WHISPER_LANGUAGE_TO_LLM_LANGUAGE``,
            an instruction to reply in that language is prepended to the text.
    """
    logger.debug("infering language model...")
    language_code = None

    if isinstance(prompt, tuple):
        prompt, language_code = prompt
        # The code may be None or a language without a mapping; indexing the
        # dict directly would raise KeyError, so only add the instruction
        # when a language name is actually known.
        language_name = WHISPER_LANGUAGE_TO_LLM_LANGUAGE.get(language_code)
        if language_name is not None:
            prompt = f"Please reply to my message in {language_name}. " + prompt

    self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
......@@ -122,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0])
yield (sentences[0], language_code)
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text
yield (printable_text, language_code)
from time import perf_counter
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForSpeechSeq2Seq
)
import torch
from copy import copy
from baseHandler import BaseHandler
from rich.console import Console
import logging
......@@ -12,6 +12,15 @@ import logging
logger = logging.getLogger(__name__)
console = Console()
# Whisper language codes the handler accepts from language detection.
SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]
class WhisperSTTHandler(BaseHandler):
"""
......@@ -24,12 +33,18 @@ class WhisperSTTHandler(BaseHandler):
device="cuda",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs
if language == 'auto':
language = None
self.last_language = language
if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
......@@ -102,11 +117,24 @@ class WhisperSTTHandler(BaseHandler):
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_code)
gen_kwargs = copy(self.gen_kwargs)
gen_kwargs['language'] = self.last_language
language_code = self.last_language
pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
self.last_language = language_code
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")
yield pred_text
yield (pred_text, language_code)
......@@ -10,21 +10,44 @@ logger = logging.getLogger(__name__)
console = Console()
# Maps a Whisper language code to the value passed as `language=` when
# constructing the Melo `TTS` model.
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
    "en": "EN_NEWEST",
    "fr": "FR",
    "es": "ES",
    "zh": "ZH",
    "ja": "JP",
    "ko": "KR",
}
# Maps a Whisper language code to the speaker name used as the key into
# the Melo model's `hps.data.spk2id` table.
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
    "en": "EN-Newest",
    "fr": "FR",
    "es": "ES",
    "zh": "ZH",
    "ja": "JP",
    "ko": "KR",
}
class MeloTTSHandler(BaseHandler):
def setup(
    self,
    should_listen,
    device="mps",
    language="en",
    speaker_to_id="en",
    gen_kwargs={},  # Unused
    blocksize=512,
):
    """
    Load the Melo TTS model for the given language and warm it up.

    Args:
        should_listen: Stored on the handler; `process` calls `.set()` on it.
        device: Device passed to the `TTS` constructor.
        language: Whisper language code; must be a key of
            `WHISPER_LANGUAGE_TO_MELO_LANGUAGE`.
        speaker_to_id: Whisper language code selecting the speaker; must be
            a key of `WHISPER_LANGUAGE_TO_MELO_SPEAKER`.
        gen_kwargs: Unused.
        blocksize: Stored as `self.blocksize`.

    Raises:
        KeyError: If `language` or `speaker_to_id` has no Melo mapping.
    """
    self.should_listen = should_listen
    self.device = device
    # Keep the Whisper-style code so `process` can compare it with the
    # language code that arrives alongside each sentence.
    self.language = language
    # Build the model once, with the Whisper code translated to Melo's name.
    self.model = TTS(
        language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
    )
    self.speaker_id = self.model.hps.data.spk2id[
        WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
    ]
    self.blocksize = blocksize
    self.warmup()
......@@ -33,7 +56,28 @@ class MeloTTSHandler(BaseHandler):
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
def process(self, llm_sentence):
language_code = None
if isinstance(llm_sentence, tuple):
llm_sentence, language_code = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
if language_code is not None and self.language != language_code:
try:
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
device=self.device,
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
]
self.language = language_code
except KeyError:
console.print(
f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
)
if self.device == "mps":
import time
......@@ -44,7 +88,13 @@ class MeloTTSHandler(BaseHandler):
time.time() - start
) # Removing this line makes it fail more often. I'm looking into it.
audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
try:
audio_chunk = self.model.tts_to_file(
llm_sentence, self.speaker_id, quiet=True
)
except (AssertionError, RuntimeError) as e:
logger.error(f"Error in MeloTTSHandler: {e}")
audio_chunk = np.array([])
if len(audio_chunk) == 0:
self.should_listen.set()
return
......
......@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
@dataclass
class MeloTTSHandlerArguments:
melo_language: str = field(
default="EN_NEWEST",
default="en",
metadata={
"help": "The language of the text to be synthesized. Default is 'EN_NEWEST'."
},
......@@ -16,7 +16,7 @@ class MeloTTSHandlerArguments:
},
)
melo_speaker_to_id: str = field(
default="EN-Newest",
default="en",
metadata={
"help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
},
......
from dataclasses import dataclass, field
from typing import Optional
@dataclass
......@@ -51,9 +52,13 @@ class WhisperSTTHandlerArguments:
"help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
},
)
stt_gen_language: str = field(
default="en",
language: Optional[str] = field(
default='en',
metadata={
"help": "The language of the speech to transcribe. Default is 'en' for English."
"help": """The language for the conversation.
Choose between 'en' (english), 'fr' (french), 'es' (spanish),
'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'.
If using 'auto', the language is automatically detected and can
change during the conversation. Default is 'en'."""
},
)
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment