From a26395c5f392b03d661e9546be60570666a35418 Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Wed, 28 Aug 2024 15:43:33 +0200
Subject: [PATCH] linting and few fixes for chattts

---
 STT/whisper_stt_handler.py              |  4 +++-
 TTS/chatTTS_handler.py                  | 27 ++++++++++---------------
 arguments_classes/chat_tts_arguments.py |  8 +++-----
 arguments_classes/module_arguments.py   |  2 +-
 requirements.txt                        |  6 +++---
 requirements_mac.txt                    |  2 +-
 s2s_pipeline.py                         | 10 ++++-----
 7 files changed, 26 insertions(+), 33 deletions(-)

diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py
index 1470bfb..6aa165c 100644
--- a/STT/whisper_stt_handler.py
+++ b/STT/whisper_stt_handler.py
@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler):
         # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
         # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
         warmup_gen_kwargs = {
-            "min_new_tokens": self.gen_kwargs["max_new_tokens"],  # Yes, assign max_new_tokens to min_new_tokens
+            "min_new_tokens": self.gen_kwargs[
+                "max_new_tokens"
+            ],  # Yes, assign max_new_tokens to min_new_tokens
             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
             **self.gen_kwargs,
         }
diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py
index 0c171ae..ee8ca25 100644
--- a/TTS/chatTTS_handler.py
+++ b/TTS/chatTTS_handler.py
@@ -18,15 +18,15 @@ class ChatTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
-        device="mps",
+        device="cuda",
         gen_kwargs={},  # Unused
         stream=True,
         chunk_size=512,
     ):
         self.should_listen = should_listen
         self.device = device
-        self.model = ChatTTS.Chat()
-        self.model.load(compile=True)  # Set to True for better performance
+        self.model = ChatTTS.Chat()
+        self.model.load(compile=False)  # Doesn't work for me with True
         self.chunk_size = chunk_size
         self.stream = stream
         rnd_spk_emb = self.model.sample_random_speaker()
@@ -37,8 +37,7 @@ class ChatTTSHandler(BaseHandler):
 
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
-        _= self.model.infer("text")
-
+        _ = self.model.infer("text")
 
     def process(self, llm_sentence):
         console.print(f"[green]ASSISTANT: {llm_sentence}")
@@ -52,36 +51,32 @@ class ChatTTSHandler(BaseHandler):
             time.time() - start
         )  # Removing this line makes it fail more often. I'm looking into it.
 
-        wavs_gen = self.model.infer(llm_sentence,params_infer_code=self.params_infer_code, stream=self.stream)
+        wavs_gen = self.model.infer(
+            llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
+        )
         if self.stream:
             wavs = [np.array([])]
             for gen in wavs_gen:
-                print('new chunk gen', len(gen[0]))
                 if len(gen[0]) == 0:
                     self.should_listen.set()
                     return
                 audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
-                audio_chunk = (audio_chunk * 32768).astype(np.int16)
-                print('audio_chunk:', audio_chunk.shape)
+                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
                 while len(audio_chunk) > self.chunk_size:
-                    yield audio_chunk[:self.chunk_size]  # yield the first chunk_size samples
-                    audio_chunk = audio_chunk[self.chunk_size:]  # drop the samples already yielded
-                yield np.pad(audio_chunk, (0,self.chunk_size-len(audio_chunk)))
+                    yield audio_chunk[: self.chunk_size]  # yield the first chunk_size samples
+                    audio_chunk = audio_chunk[self.chunk_size :]  # drop the samples already yielded
+                yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
         else:
-            print('check result', wavs_gen)
             wavs = wavs_gen
             if len(wavs[0]) == 0:
                 self.should_listen.set()
                 return
             audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
             audio_chunk = (audio_chunk * 32768).astype(np.int16)
-            print('audio_chunk:', audio_chunk.shape)
             for i in range(0, len(audio_chunk), self.chunk_size):
                 yield np.pad(
                     audio_chunk[i : i + self.chunk_size],
                     (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
                 )
         self.should_listen.set()
 
-
-
diff --git a/arguments_classes/chat_tts_arguments.py b/arguments_classes/chat_tts_arguments.py
index 096c1a2..bccce27 100644
--- a/arguments_classes/chat_tts_arguments.py
+++ b/arguments_classes/chat_tts_arguments.py
@@ -5,14 +5,12 @@ from dataclasses import dataclass, field
 class ChatTTSHandlerArguments:
     chat_tts_stream: bool = field(
         default=True,
-        metadata={
-            "help": "The tts mode is stream Default is 'stream'."
-        },
+        metadata={"help": "The tts mode is stream Default is 'stream'."},
     )
     chat_tts_device: str = field(
-        default="mps",
+        default="cuda",
         metadata={
-            "help": "The device to be used for speech synthesis. Default is 'mps'."
+            "help": "The device to be used for speech synthesis. Default is 'cuda'."
         },
     )
     chat_tts_chunk_size: int = field(
diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py
index df9d942..8bf4884 100644
--- a/arguments_classes/module_arguments.py
+++ b/arguments_classes/module_arguments.py
@@ -35,7 +35,7 @@ class ModuleArguments:
     tts: Optional[str] = field(
         default="parler",
         metadata={
-            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
+            "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
         },
     )
     log_level: str = field(
diff --git a/requirements.txt b/requirements.txt
index 4acd623..78a37b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
 melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
 torch==2.4.0
 sounddevice==0.5.0
-ChatTTS
-funasr
-modelscope
\ No newline at end of file
+ChatTTS>=0.1.1
+funasr>=1.1.6
+modelscope>=1.17.1
\ No newline at end of file
diff --git a/requirements_mac.txt b/requirements_mac.txt
index 5e26fd4..3bf9cb7 100644
--- a/requirements_mac.txt
+++ b/requirements_mac.txt
@@ -5,6 +5,6 @@ torch==2.4.0
 sounddevice==0.5.0
 lightning-whisper-mlx>=0.0.10
 mlx-lm>=0.14.0
-ChatTTS
+ChatTTS>=0.1.1
 funasr>=1.1.6
 modelscope>=1.17.1
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 4138971..8da8298 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -80,7 +80,7 @@ def main():
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
-            ChatTTSHandlerArguments
+            ChatTTSHandlerArguments,
         )
     )
 
@@ -190,7 +190,7 @@ def main():
     prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
-    prepare_args(chat_tts_handler_kwargs,"chat_tts")
+    prepare_args(chat_tts_handler_kwargs, "chat_tts")
 
     # 3. Build the pipeline
     stop_event = Event()
@@ -319,9 +319,7 @@ def main():
         try:
             from TTS.chatTTS_handler import ChatTTSHandler
         except RuntimeError as e:
-            logger.error(
-                "Error importing ChatTTSHandler"
-            )
+            logger.error("Error importing ChatTTSHandler")
             raise e
         tts = ChatTTSHandler(
             stop_event,
@@ -331,7 +329,7 @@ def main():
             setup_kwargs=vars(chat_tts_handler_kwargs),
         )
     else:
-        raise ValueError("The TTS should be either parler or melo")
+        raise ValueError("The TTS should be either parler, melo or chatTTS")
 
     # 4. Run the pipeline
     try:
-- 
GitLab
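
P.S. (sketch, not part of the patch): the core of the ChatTTS change is how process() frames resampled audio into fixed-size chunks. Below is a minimal standalone rendering of that logic, assuming 1-D float32 input at 24 kHz and the pipeline's 512-sample int16 frames at 16 kHz; chunk_audio and the example values are illustrative, not names from this repo.

import numpy as np
import librosa


def chunk_audio(wav_24k, chunk_size=512):
    # Resample ChatTTS's 24 kHz float output to the pipeline's 16 kHz,
    # scale to int16, then yield fixed-size frames; the last partial
    # frame is zero-padded to exactly chunk_size samples.
    audio = librosa.resample(wav_24k, orig_sr=24000, target_sr=16000)
    audio = (audio * 32768).astype(np.int16)
    while len(audio) > chunk_size:
        yield audio[:chunk_size]
        audio = audio[chunk_size:]
    yield np.pad(audio, (0, chunk_size - len(audio)))


# One second of 24 kHz audio -> 16000 samples at 16 kHz -> 32 frames of 512.
frames = list(chunk_audio(np.zeros(24000, dtype=np.float32)))
assert len(frames) == 32 and all(len(f) == 512 for f in frames)

Zero-padding the final partial frame means every yielded chunk is exactly chunk_size samples, so the downstream sender never has to handle ragged frames.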