diff --git a/LLM/mlx_language_model.py b/LLM/mlx_language_model.py
index 7e041de3e8bcb82ed9332eb778c6b03884f010b0..ae11b35e7e99e608cea493a5184d18734e2a54eb 100644
--- a/LLM/mlx_language_model.py
+++ b/LLM/mlx_language_model.py
@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
 
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
 
 class MLXLanguageModelHandler(BaseHandler):
     """
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
 
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
 
@@ -86,9 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
             output += t
             curr_output += t
             if curr_output.endswith((".", "?", "!", "<|end|>")):
-                yield curr_output.replace("<|end|>", "")
+                yield (curr_output.replace("<|end|>", ""), language_code)
                 curr_output = ""
         generated_text = output.replace("<|end|>", "")
         torch.mps.empty_cache()
 
-        self.chat.append({"role": "assistant", "content": generated_text})
+        self.chat.append({"role": "assistant", "content": generated_text})
\ No newline at end of file
diff --git a/README.md b/README.md
index 93c5c6cc2b2cd29695d5018cb641c2b68ce7c6f7..fde24ccfdcb3eb1f4db985f86c2da1bbfa0d39b0 100644
--- a/README.md
+++ b/README.md
@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
 
 ### Server/Client Approach
 
-To run the pipeline on the server:
-```bash
-python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
-```
+1. Run the pipeline on the server:
+   ```bash
+   python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
+   ```
 
-Then run the client locally to handle sending microphone input and receiving generated audio:
-```bash
-python listen_and_play.py --host <IP address of your server>
-```
+2. Run the client locally to handle microphone input and receive generated audio:
+   ```bash
+   python listen_and_play.py --host <IP address of your server>
+   ```
 
-### Local approach (running on Mac)
-To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
-```bash
-python s2s_pipeline.py --local_mac_optimal_settings
-```
+### Local Approach (Mac)
+
+1. For optimal settings on Mac:
+   ```bash
+   python s2s_pipeline.py --local_mac_optimal_settings
+   ```
 
-You can also pass `--device mps` to have all the models set to device mps.
-The local mac optimal settings set the mode to be local as explained above and change the models to:
-- LightningWhisperMLX
-- MLX LM
-- MeloTTS
+This flag:
+- Sets the device to MPS for all models (equivalent to passing `--device mps`)
+- Uses LightningWhisperMLX for STT
+- Uses MLX LM for the language model
+- Uses MeloTTS for TTS
 
 ### Recommended usage with Cuda
 
@@ -117,6 +118,57 @@ python s2s_pipeline.py \
 
 For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
 
+
+### Multi-language Support
+
+The pipeline supports multiple languages, allowing for automatic language detection or a specific language setting. Here are examples for both the server and local (Mac) setups:
+
+#### Server Setup
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
+```
+
+Or to force a single language, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
+```
+
+#### Local Mac Setup
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
+```
+
+Or to force a single language, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
+```
+
 
 ## Command-line Usage
 
 ### Model Parameters
diff --git a/STT/lightning_whisper_mlx_handler.py b/STT/lightning_whisper_mlx_handler.py
index 4785b73853275e5a308ae9da48667b0b656297df..6f9fbb217bb32b227096e4594f6e648fa879d37a 100644
--- a/STT/lightning_whisper_mlx_handler.py
+++ b/STT/lightning_whisper_mlx_handler.py
@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
 from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
+from copy import copy
 import torch
 
 logger = logging.getLogger(__name__)
 
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
+
 
 class LightningWhisperSTTHandler(BaseHandler):
     """
@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-large-v3",
-        device="cuda",
+        device="mps",
         torch_dtype="float16",
         compile_mode=None,
         language=None,
@@ -29,6 +39,9 @@ class LightningWhisperSTTHandler(BaseHandler):
         model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+        self.start_language = language
+        self.last_language = language
+
         self.warmup()
 
     def warmup(self):
@@ -47,10 +60,26 @@ class LightningWhisperSTTHandler(BaseHandler):
         global pipeline_start
         pipeline_start = perf_counter()
 
-        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+        if self.start_language != 'auto':
+            transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
+        else:
+            transcription_dict = self.model.transcribe(spoken_prompt)
+            language_code = transcription_dict["language"]
+            if language_code not in SUPPORTED_LANGUAGES:
+                logger.warning(f"Whisper detected unsupported language: {language_code}")
+                if self.last_language in SUPPORTED_LANGUAGES:  # reprocess with the last language
+                    transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
+                else:
+                    transcription_dict = {"text": "", "language": "en"}
+            else:
+                self.last_language = language_code
+
+        pred_text = transcription_dict["text"].strip()
+        language_code = transcription_dict["language"]
 
         torch.mps.empty_cache()
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)
diff --git a/VAD/vad_handler.py b/VAD/vad_handler.py
index 1dc64008f142cac08a3ca1cb7ed068f00a221097..3f5c6acbe6dde672a5259f57a019b36785e14ed3 100644
--- a/VAD/vad_handler.py
+++ b/VAD/vad_handler.py
@@ -86,3 +86,7 @@ class VADHandler(BaseHandler):
                 )
                 array = enhanced.numpy().squeeze()
             yield array
+
+    @property
+    def min_time_to_debug(self):
+        return 0.00001
diff --git a/baseHandler.py b/baseHandler.py
index 6f5efa80adb473cff0b50336ddbc744fe50ba27c..61532e4705ed7efee02c4c2b4079a3c68526bfde 100644
--- a/baseHandler.py
+++ b/baseHandler.py
@@ -36,7 +36,8 @@ class BaseHandler:
             start_time = perf_counter()
             for output in self.process(input):
                 self._times.append(perf_counter() - start_time)
-                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
+                if self.last_time > self.min_time_to_debug:
+                    logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
                 self.queue_out.put(output)
                 start_time = perf_counter()
 
@@ -46,6 +47,10 @@ class BaseHandler:
     @property
     def last_time(self):
         return self._times[-1]
+
+    @property
+    def min_time_to_debug(self):
+        return 0.001
 
     def cleanup(self):
         pass
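
The net effect of these changes is a small contract between handlers: the STT handler yields `(text, language_code)` tuples, and the LLM handler unpacks them, prefixes the prompt with a language instruction, and forwards the language code with each generated sentence. Below is a minimal, self-contained sketch of that contract. The helper functions `stt_output` and `build_llm_prompt` are illustrative stand-ins, not code from the repository; only the tuple shape and the prompt prefix mirror the actual changes to `LightningWhisperSTTHandler.process` and `MLXLanguageModelHandler.process`.

```python
# Illustrative sketch only: these helpers are NOT part of the repository.
# They mirror the (text, language_code) tuple contract introduced by this diff.

WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}


def stt_output(pred_text, language_code):
    # The STT handler now yields a tuple instead of a bare string.
    return (pred_text, language_code)


def build_llm_prompt(item):
    # The LLM handler unpacks the tuple and, when a language code is present,
    # asks the model to answer in the matching language.
    language_code = None
    prompt = item
    if isinstance(item, tuple):
        prompt, language_code = item
        prompt = (
            f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. "
            + prompt
        )
    return prompt, language_code


if __name__ == "__main__":
    text, lang = stt_output("你好，今天天气怎么样？", "zh")
    prompt, lang = build_llm_prompt((text, lang))
    print(prompt)  # Please reply to my message in chinese. 你好，今天天气怎么样？
    print(lang)    # zh, forwarded with each generated sentence to the downstream handlers
```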