diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c171ae2010f6ebf321081a3696c619229d41289
--- /dev/null
+++ b/TTS/chatTTS_handler.py
@@ -0,0 +1,111 @@
+import ChatTTS
+import logging
+from baseHandler import BaseHandler
+import librosa
+import numpy as np
+from rich.console import Console
+import torch
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+class ChatTTSHandler(BaseHandler):
+    """Text-to-speech handler that synthesizes LLM sentences with ChatTTS."""
+
+    def setup(
+        self,
+        should_listen,
+        device="mps",
+        gen_kwargs=None,  # Unused
+        stream=True,
+        chunk_size=512,
+    ):
+        """Load the ChatTTS model, pick a random speaker and warm up.
+
+        Args:
+            should_listen: Object whose ``set()`` is called when a
+                sentence has been fully emitted.
+            device: Device name; only used to decide whether the MPS
+                cache is flushed in ``process``.
+            gen_kwargs: Accepted for interface compatibility; unused.
+            stream: Whether to request streamed output from ChatTTS.
+            chunk_size: Number of samples per emitted audio chunk.
+        """
+        self.should_listen = should_listen
+        # NOTE(review): device is not forwarded to the ChatTTS loader here;
+        # confirm the library selects the intended device on its own.
+        self.device = device
+        self.model = ChatTTS.Chat()
+        self.model.load(compile=True)  # Set to True for better performance
+        self.chunk_size = chunk_size
+        self.stream = stream
+        rnd_spk_emb = self.model.sample_random_speaker()
+        self.params_infer_code = ChatTTS.Chat.InferCodeParams(
+            spk_emb=rnd_spk_emb,
+        )
+        self.warmup()
+
+    def warmup(self):
+        """Run one throwaway inference on a placeholder text."""
+        logger.info(f"Warming up {self.__class__.__name__}")
+        _ = self.model.infer("text")
+
+    def _to_pcm16(self, wav):
+        """Resample 24 kHz float audio to 16 kHz and convert it to int16."""
+        audio = librosa.resample(wav, orig_sr=24000, target_sr=16000)
+        # Clip before casting: a sample at +1.0 scales to 32768, which does
+        # not fit in int16.
+        pcm = np.clip(audio * 32768, -32768, 32767).astype(np.int16)
+        logger.debug("audio_chunk: %s", pcm.shape)
+        return pcm
+
+    def _iter_chunks(self, audio):
+        """Yield ``audio`` in ``chunk_size`` pieces, zero-padding the last."""
+        for start in range(0, len(audio), self.chunk_size):
+            chunk = audio[start : start + self.chunk_size]
+            yield np.pad(chunk, (0, self.chunk_size - len(chunk)))
+
+    def process(self, llm_sentence):
+        """Synthesize ``llm_sentence`` and yield fixed-size int16 chunks.
+
+        Every yielded array has ``chunk_size`` samples at 16 kHz; the last
+        chunk of each waveform is zero-padded. ``should_listen`` is set when
+        the sentence is finished or as soon as ChatTTS returns empty audio.
+        Assumes ChatTTS returns 1-D waveforms - TODO confirm.
+        """
+        console.print(f"[green]ASSISTANT: {llm_sentence}")
+        if self.device == "mps":
+            import time
+
+            start = time.time()
+            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
+            torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
+            _ = (
+                time.time() - start
+            )  # Removing this line makes it fail more often. I'm looking into it.
+
+        wavs_gen = self.model.infer(
+            llm_sentence,
+            params_infer_code=self.params_infer_code,
+            stream=self.stream,
+        )
+
+        if self.stream:
+            for gen in wavs_gen:
+                logger.debug("new chunk gen %d", len(gen[0]))
+                if len(gen[0]) == 0:
+                    self.should_listen.set()
+                    return
+                yield from self._iter_chunks(self._to_pcm16(gen[0]))
+        else:
+            logger.debug("check result %s", wavs_gen)
+            if len(wavs_gen[0]) == 0:
+                self.should_listen.set()
+                return
+            yield from self._iter_chunks(self._to_pcm16(wavs_gen[0]))
+        self.should_listen.set()
diff --git a/arguments_classes/chat_tts_arguments.py b/arguments_classes/chat_tts_arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..096c1a2a176e2e8e350983e4f28adb04e121e035
--- /dev/null
+++ b/arguments_classes/chat_tts_arguments.py
@@ -0,0 +1,25 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class ChatTTSHandlerArguments:
+    """Options for the ChatTTS handler, parsed by the pipeline's argument parser."""
+
+    chat_tts_stream: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether ChatTTS synthesizes speech in streaming mode. Default is True."
+        },
+    )
+    chat_tts_device: str = field(
+        default="mps",
+        metadata={
+            "help": "The device to be used for speech synthesis. Default is 'mps'."
+        },
+    )
+    chat_tts_chunk_size: int = field(
+        default=512,
+        metadata={
+            "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load. Default is 512."
+        },
+    )
diff --git a/requirements.txt b/requirements.txt
index b4a5a0e36820f6b4a6900ada8d3a4b8a6059f669..a0f01c79042ad5fa1f4147881d3b3f57b7363799 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,7 @@ nltk==3.9.1
 parler_tts @ git+https://github.com/huggingface/parler-tts.git
 melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
 torch==2.4.0
-sounddevice==0.5.0
\ No newline at end of file
+sounddevice==0.5.0
+funasr
+modelscope
+ChatTTS
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index c5d8e133be2d493eb6754f7532a5c779de31333c..dbd13193cb7fc386bab51042d88837e7156380a2 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -8,6 +8,7 @@ from threading import Event
 from typing import Optional
 from sys import platform
 from VAD.vad_handler import VADHandler
+from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
 from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
 from arguments_classes.mlx_language_model_arguments import (
     MLXLanguageModelHandlerArguments,
@@ -77,6 +78,7 @@ def main():
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
+            ChatTTSHandlerArguments,
         )
     )
 
@@ -93,6 +95,7 @@ def main():
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
+            chat_tts_handler_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
@@ -106,6 +109,7 @@ def main():
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
+            chat_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()
 
     # 1. Handle logger
@@ -178,6 +182,7 @@ def main():
     prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
+    prepare_args(chat_tts_handler_kwargs, "chat_tts")
 
     # 3. Build the pipeline
     stop_event = Event()
@@ -291,6 +296,19 @@ def main():
             setup_args=(should_listen,),
             setup_kwargs=vars(melo_tts_handler_kwargs),
         )
+    elif module_kwargs.tts == "chatTTS":
+        try:
+            from TTS.chatTTS_handler import ChatTTSHandler
+        except (ImportError, RuntimeError):
+            logger.error("Error importing ChatTTSHandler")
+            raise
+        tts = ChatTTSHandler(
+            stop_event,
+            queue_in=lm_response_queue,
+            queue_out=send_audio_chunks_queue,
+            setup_args=(should_listen,),
+            setup_kwargs=vars(chat_tts_handler_kwargs),
+        )
     else:
-        raise ValueError("The TTS should be either parler or melo")
+        raise ValueError("The TTS should be either parler, melo or chatTTS")
 