diff --git a/STT/whisper_stt_handler.py b/STT/whisper_stt_handler.py index 1470bfbdbf0a307d4128a556190d3d95e84fd7e7..6aa165c7746ecc7a239125c4af1fe69c70ec8d63 100644 --- a/STT/whisper_stt_handler.py +++ b/STT/whisper_stt_handler.py @@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler): # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense warmup_gen_kwargs = { - "min_new_tokens": self.gen_kwargs["max_new_tokens"], # Yes, assign max_new_tokens to min_new_tokens + "min_new_tokens": self.gen_kwargs[ + "max_new_tokens" + ], # Yes, assign max_new_tokens to min_new_tokens "max_new_tokens": self.gen_kwargs["max_new_tokens"], **self.gen_kwargs, } diff --git a/TTS/chatTTS_handler.py b/TTS/chatTTS_handler.py index 0c171ae2010f6ebf321081a3696c619229d41289..ee8ca25e9d6fb0a1245211558ed0618d491be264 100644 --- a/TTS/chatTTS_handler.py +++ b/TTS/chatTTS_handler.py @@ -18,15 +18,15 @@ class ChatTTSHandler(BaseHandler): def setup( self, should_listen, - device="mps", + device="cuda", gen_kwargs={}, # Unused stream=True, chunk_size=512, ): self.should_listen = should_listen self.device = device - self.model = ChatTTS.Chat() - self.model.load(compile=True) # Set to True for better performance + self.model = ChatTTS.Chat() + self.model.load(compile=False) # Doesn't work for me with True self.chunk_size = chunk_size self.stream = stream rnd_spk_emb = self.model.sample_random_speaker() @@ -37,8 +37,7 @@ class ChatTTSHandler(BaseHandler): def warmup(self): logger.info(f"Warming up {self.__class__.__name__}") - _= self.model.infer("text") - + _ = self.model.infer("text") def process(self, llm_sentence): console.print(f"[green]ASSISTANT: {llm_sentence}") @@ -52,36 +51,32 @@ class ChatTTSHandler(BaseHandler): time.time() - start ) # Removing this line makes it fail more often. I'm looking into it. 
- wavs_gen = self.model.infer(llm_sentence,params_infer_code=self.params_infer_code, stream=self.stream) + wavs_gen = self.model.infer( + llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream + ) if self.stream: wavs = [np.array([])] for gen in wavs_gen: - print('new chunk gen', len(gen[0])) if len(gen[0]) == 0: self.should_listen.set() return audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000) - audio_chunk = (audio_chunk * 32768).astype(np.int16) - print('audio_chunk:', audio_chunk.shape) + audio_chunk = (audio_chunk * 32768).astype(np.int16)[0] while len(audio_chunk) > self.chunk_size: - yield audio_chunk[:self.chunk_size] # 返回前 chunk_size 字节的数据 - audio_chunk = audio_chunk[self.chunk_size:] # 移除已返回的数据 - yield np.pad(audio_chunk, (0,self.chunk_size-len(audio_chunk))) + yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据 + audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据 + yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk))) else: - print('check result', wavs_gen) wavs = wavs_gen if len(wavs[0]) == 0: self.should_listen.set() return audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000) audio_chunk = (audio_chunk * 32768).astype(np.int16) - print('audio_chunk:', audio_chunk.shape) for i in range(0, len(audio_chunk), self.chunk_size): yield np.pad( audio_chunk[i : i + self.chunk_size], (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), ) self.should_listen.set() - - diff --git a/arguments_classes/chat_tts_arguments.py b/arguments_classes/chat_tts_arguments.py index 096c1a2a176e2e8e350983e4f28adb04e121e035..bccce27176a4e2e818a2285ebdfa2c2cd63d69c9 100644 --- a/arguments_classes/chat_tts_arguments.py +++ b/arguments_classes/chat_tts_arguments.py @@ -5,14 +5,12 @@ from dataclasses import dataclass, field class ChatTTSHandlerArguments: chat_tts_stream: bool = field( default=True, - metadata={ - "help": "The tts 
mode is stream Default is 'stream'." - }, + metadata={"help": "The tts mode is stream Default is 'stream'."}, ) chat_tts_device: str = field( - default="mps", + default="cuda", metadata={ - "help": "The device to be used for speech synthesis. Default is 'mps'." + "help": "The device to be used for speech synthesis. Default is 'cuda'." }, ) chat_tts_chunk_size: int = field( diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py index df9d94286965d23a75e2f71d5594f99aeb9148fe..8bf4884e54c7a55a0ba783c9801d4a5b88026c56 100644 --- a/arguments_classes/module_arguments.py +++ b/arguments_classes/module_arguments.py @@ -35,7 +35,7 @@ class ModuleArguments: tts: Optional[str] = field( default="parler", metadata={ - "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'" + "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'" }, ) log_level: str = field( diff --git a/requirements.txt b/requirements.txt index 4acd623c08f850a6504b902c0263fc382d086f2b..78a37b65ef6ed088a02c52180fbc3a598ce19726 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers torch==2.4.0 sounddevice==0.5.0 -ChatTTS -funasr -modelscope \ No newline at end of file +ChatTTS>=0.1.1 +funasr>=1.1.6 +modelscope>=1.17.1 \ No newline at end of file diff --git a/requirements_mac.txt b/requirements_mac.txt index 5e26fd46339d0dbcd8702954a0dc9a4be8cdec77..3bf9cb757d63ad88264f4aebe8bb6bb506eabe04 100644 --- a/requirements_mac.txt +++ b/requirements_mac.txt @@ -5,6 +5,6 @@ torch==2.4.0 sounddevice==0.5.0 lightning-whisper-mlx>=0.0.10 mlx-lm>=0.14.0 -ChatTTS +ChatTTS>=0.1.1 funasr>=1.1.6 modelscope>=1.17.1 diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 
4138971e247f7c80e2ca4bd892cce9689bcd1ea9..8da829834e85c856458a571bf3c7242500d8ae6b 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -80,7 +80,7 @@ def main(): MLXLanguageModelHandlerArguments, ParlerTTSHandlerArguments, MeloTTSHandlerArguments, - ChatTTSHandlerArguments + ChatTTSHandlerArguments, ) ) @@ -190,7 +190,7 @@ def main(): prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(melo_tts_handler_kwargs, "melo") - prepare_args(chat_tts_handler_kwargs,"chat_tts") + prepare_args(chat_tts_handler_kwargs, "chat_tts") # 3. Build the pipeline stop_event = Event() @@ -319,9 +319,7 @@ def main(): try: from TTS.chatTTS_handler import ChatTTSHandler except RuntimeError as e: - logger.error( - "Error importing ChatTTSHandler" - ) + logger.error("Error importing ChatTTSHandler") raise e tts = ChatTTSHandler( stop_event, @@ -331,7 +329,7 @@ def main(): setup_kwargs=vars(chat_tts_handler_kwargs), ) else: - raise ValueError("The TTS should be either parler or melo") + raise ValueError("The TTS should be either parler, melo or chatTTS") # 4. Run the pipeline try: