diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 1470bfbdbf0a307d4128a556190d3d95e84fd7e7..6aa165c7746ecc7a239125c4af1fe69c70ec8d63 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler):
         # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
         # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
         warmup_gen_kwargs = {
-            "min_new_tokens": self.gen_kwargs["max_new_tokens"],  # Yes, assign max_new_tokens to min_new_tokens
+            "min_new_tokens": self.gen_kwargs[
+                "max_new_tokens"
+            ],  # Yes, assign max_new_tokens to min_new_tokens
             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
             **self.gen_kwargs,
         }
diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bdc6bfe31929c930726df594fdf296ab3e21ce7
--- /dev/null
+++ b/TTS/chatTTS_handler.py
@@ -0,0 +1,82 @@
+import ChatTTS
+import logging
+from baseHandler import BaseHandler
+import librosa
+import numpy as np
+from rich.console import Console
+import torch
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class ChatTTSHandler(BaseHandler):
+    def setup(
+        self,
+        should_listen,
+        device="cuda",
+        gen_kwargs={},  # Unused
+        stream=True,
+        chunk_size=512,
+    ):
+        self.should_listen = should_listen
+        self.device = device
+        self.model = ChatTTS.Chat()
+        self.model.load(compile=False)  # Doesn't work for me with True
+        self.chunk_size = chunk_size
+        self.stream = stream
+        rnd_spk_emb = self.model.sample_random_speaker()
+        self.params_infer_code = ChatTTS.Chat.InferCodeParams(
+            spk_emb=rnd_spk_emb,
+        )
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+        _ = self.model.infer("text")
+
+    def process(self, llm_sentence):
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        if self.device == "mps":
+            import time
+
+            start = time.time()
+            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
+            torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
+            _ = (
+                time.time() - start
+            )  # Removing this line makes it fail more often. I'm looking into it.
+
+        wavs_gen = self.model.infer(
+            llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
+        )
+
+        if self.stream:
+            wavs = [np.array([])]
+            for gen in wavs_gen:
+                if gen[0] is None or len(gen[0]) == 0:
+                    self.should_listen.set()
+                    return
+                audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
+                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
+                while len(audio_chunk) > self.chunk_size:
+                    yield audio_chunk[: self.chunk_size]  # yield the first chunk_size samples
+                    audio_chunk = audio_chunk[self.chunk_size :]  # drop the samples already yielded
+                yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
+        else:
+            wavs = wavs_gen
+            if len(wavs[0]) == 0:
+                self.should_listen.set()
+                return
+            audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
+            audio_chunk = (audio_chunk * 32768).astype(np.int16)
+            for i in range(0, len(audio_chunk), self.chunk_size):
+                yield np.pad(
+                    audio_chunk[i : i + self.chunk_size],
+                    (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
+                )
+        self.should_listen.set()
diff --git a/arguments_classes/chat_tts_arguments.py b/arguments_classes/chat_tts_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..bccce27176a4e2e818a2285ebdfa2c2cd63d69c9
--- /dev/null
+++ b/arguments_classes/chat_tts_arguments.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ChatTTSHandlerArguments:
+    chat_tts_stream: bool = field(
+        default=True,
+        metadata={"help": "Whether to stream the TTS output. Default is True."},
+    )
+    chat_tts_device: str = field(
+        default="cuda",
+        metadata={
+            "help": "The device to be used for speech synthesis. Default is 'cuda'."
+        },
+    )
+    chat_tts_chunk_size: int = field(
+        default=512,
+        metadata={
+            "help": "Size of the audio chunk processed per cycle, balancing playback latency and CPU load. Default is 512."
+        },
+    )
diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py
index df9d94286965d23a75e2f71d5594f99aeb9148fe..8bf4884e54c7a55a0ba783c9801d4a5b88026c56 100644
--- a/arguments_classes/module_arguments.py
+++ b/arguments_classes/module_arguments.py
@@ -35,7 +35,7 @@ class ModuleArguments:
     tts: Optional[str] = field(
         default="parler",
         metadata={
-            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
+            "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
         },
     )
     log_level: str = field(
diff --git a/requirements.txt b/requirements.txt
index 71f85b7185cc7bba07722844083ab89422b154ac..fba30cd7f5e716797d29b3dd5890fd1a610d06a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
 melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
 torch==2.4.0
 sounddevice==0.5.0
-funasr
-modelscope
-deepfilternet
+ChatTTS>=0.1.1
+funasr>=1.1.6
+modelscope>=1.17.1
+deepfilternet>=0.5.6
diff --git a/requirements_mac.txt b/requirements_mac.txt
index 7dcd4b0ed3887393edbc038c30fc166edd645681..4a1c5cbb4a101ce611a2b81e4d52b73259782a0c 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -5,6 +5,8 @@ torch==2.4.0
 sounddevice==0.5.0
 lightning-whisper-mlx>=0.0.10
 mlx-lm>=0.14.0
+ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
-deepfilternet
+deepfilternet>=0.5.6
+
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 1231abd824efe190d8036d78c7a508f058f6edb3..8da829834e85c856458a571bf3c7242500d8ae6b 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -8,6 +8,7 @@ from threading import Event
 from typing import Optional
 from sys import platform
 from VAD.vad_handler import VADHandler
+from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
 from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
 from arguments_classes.mlx_language_model_arguments import (
     MLXLanguageModelHandlerArguments,
@@ -79,6 +80,7 @@ def main():
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
+            ChatTTSHandlerArguments,
         )
     )
 
@@ -96,6 +98,7 @@ def main():
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
+            chat_tts_handler_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
@@ -110,6 +113,7 @@ def main():
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
+            chat_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()
 
     # 1. Handle logger
@@ -186,6 +190,7 @@ def main():
     prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
+    prepare_args(chat_tts_handler_kwargs, "chat_tts")
 
     # 3. Build the pipeline
     stop_event = Event()
@@ -310,8 +315,21 @@ def main():
             setup_args=(should_listen,),
             setup_kwargs=vars(melo_tts_handler_kwargs),
         )
+    elif module_kwargs.tts == "chatTTS":
+        try:
+            from TTS.chatTTS_handler import ChatTTSHandler
+        except RuntimeError as e:
+            logger.error("Error importing ChatTTSHandler")
+            raise e
+        tts = ChatTTSHandler(
+            stop_event,
+            queue_in=lm_response_queue,
+            queue_out=send_audio_chunks_queue,
+            setup_args=(should_listen,),
+            setup_kwargs=vars(chat_tts_handler_kwargs),
+        )
     else:
-        raise ValueError("The TTS should be either parler or melo")
+        raise ValueError("The TTS should be either parler, melo or chatTTS")
 
     # 4. Run the pipeline
     try:
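
With these changes, the new backend is selected through the existing module arguments. A minimal usage sketch, assuming HfArgumentParser exposes the dataclass fields above as same-named CLI flags (the exact flag spellings are an assumption, not taken from this patch):

    # select ChatTTS as the TTS module and tune its chunking for lower latency
    python s2s_pipeline.py --tts chatTTS --chat_tts_device cuda --chat_tts_chunk_size 512

The chat_tts_chunk_size value trades playback latency against per-cycle CPU load, mirroring the default of 512 set in ChatTTSHandlerArguments.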