Skip to content
Snippets Groups Projects
Commit 71f25443 authored by wuhongsheng's avatar wuhongsheng
Browse files

feat:add chatTTS

parent fc9f9602
No related branches found
No related tags found
No related merge requests found
import ChatTTS
import logging
from baseHandler import BaseHandler
import librosa
import numpy as np
from rich.console import Console
import torch
# Each log record shows a timestamp, the logger name, the severity and the message.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

# Rich console used for user-facing output; module logger used for diagnostics.
console = Console()
logger = logging.getLogger(__name__)
class ChatTTSHandler(BaseHandler):
    """Text-to-speech handler that synthesizes sentences with ChatTTS.

    Each incoming sentence is passed to the ChatTTS model; the resulting
    waveform is resampled from 24 kHz to 16 kHz, converted to int16 PCM and
    yielded as arrays of exactly ``chunk_size`` samples (the last one is
    zero-padded).
    """

    # Sample rates passed to the resampler (input rate -> output rate).
    _SOURCE_SAMPLE_RATE = 24000
    _TARGET_SAMPLE_RATE = 16000

    def setup(
        self,
        should_listen,
        device="mps",
        gen_kwargs=None,  # Unused; None instead of a shared mutable default
        stream=True,
        chunk_size=512,
    ):
        """Load the ChatTTS model, pick a random speaker and warm the model up.

        Args:
            should_listen: Event-like object; ``set()`` is called on it once a
                sentence has been processed.
            device: Device name; ``"mps"`` enables the MPS cache handling in
                ``process``.
            gen_kwargs: Accepted for interface compatibility; not used.
            stream: Whether ``infer`` is called in streaming mode.
            chunk_size: Number of samples in every yielded audio chunk.
        """
        self.should_listen = should_listen
        self.device = device
        self.model = ChatTTS.Chat()
        self.model.load(compile=True)  # Set to True for better performance
        self.chunk_size = chunk_size
        self.stream = stream
        # One random speaker embedding is sampled here and reused for every
        # sentence handled by this instance.
        rnd_spk_emb = self.model.sample_random_speaker()
        self.params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_emb=rnd_spk_emb,
        )
        self.warmup()

    def warmup(self):
        """Run one throwaway inference so later calls do not pay start-up cost."""
        logger.info("Warming up %s", self.__class__.__name__)
        _ = self.model.infer("text")

    def _pcm_chunks(self, wav):
        """Yield ``wav`` as zero-padded int16 chunks of ``self.chunk_size`` samples."""
        resampled = librosa.resample(
            wav,
            orig_sr=self._SOURCE_SAMPLE_RATE,
            target_sr=self._TARGET_SAMPLE_RATE,
        )
        # Clip before the cast: a full-scale sample of 1.0 scales to 32768,
        # which does not fit in int16 and would otherwise wrap around.
        int16_range = np.iinfo(np.int16)
        pcm = np.clip(resampled * 32768, int16_range.min, int16_range.max).astype(
            np.int16
        )
        logger.debug("audio_chunk: %s", pcm.shape)
        for offset in range(0, len(pcm), self.chunk_size):
            piece = pcm[offset : offset + self.chunk_size]
            # Only the final piece can be short; pad it up to chunk_size.
            yield np.pad(piece, (0, self.chunk_size - len(piece)))

    def process(self, llm_sentence):
        """Synthesize ``llm_sentence`` and yield fixed-size int16 audio chunks.

        ``should_listen`` is set when synthesis finishes, or as soon as the
        model returns an empty waveform.
        """
        console.print(f"[green]ASSISTANT: {llm_sentence}")
        if self.device == "mps":
            import time

            start = time.time()
            # Wait for all kernels in all streams on the MPS device to complete.
            torch.mps.synchronize()
            # Free all memory allocated by the MPS device.
            torch.mps.empty_cache()
            _ = (
                time.time() - start
            )  # Removing this line makes it fail more often. I'm looking into it.

        wavs_gen = self.model.infer(
            llm_sentence,
            params_infer_code=self.params_infer_code,
            stream=self.stream,
        )

        if self.stream:
            # In streaming mode infer() is iterated; each item holds the
            # waveform generated so far for the sentence at index 0.
            for gen in wavs_gen:
                logger.debug("new chunk gen %d", len(gen[0]))
                if len(gen[0]) == 0:
                    self.should_listen.set()
                    return
                yield from self._pcm_chunks(gen[0])
        else:
            logger.debug("check result %s", wavs_gen)
            wavs = wavs_gen
            if len(wavs[0]) == 0:
                self.should_listen.set()
                return
            yield from self._pcm_chunks(wavs[0])
        self.should_listen.set()
from dataclasses import dataclass, field
@dataclass
class ChatTTSHandlerArguments:
    """Options for the ChatTTS text-to-speech handler.

    Every field carries the ``chat_tts_`` prefix; the ``help`` metadata is the
    description shown for the corresponding option.
    """

    # Whether speech is synthesized in streaming mode.
    chat_tts_stream: bool = field(
        default=True,
        metadata={
            "help": "Whether to synthesize speech in streaming mode. Default is True."
        },
    )
    # Device used for speech synthesis.
    chat_tts_device: str = field(
        default="mps",
        metadata={
            "help": "The device to be used for speech synthesis. Default is 'mps'."
        },
    )
    # Number of samples per emitted audio chunk.
    chat_tts_chunk_size: int = field(
        default=512,
        metadata={
            "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load. Default is 512."
        },
    )
...@@ -2,4 +2,7 @@ nltk==3.9.1 ...@@ -2,4 +2,7 @@ nltk==3.9.1
parler_tts @ git+https://github.com/huggingface/parler-tts.git parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0 torch==2.4.0
sounddevice==0.5.0 sounddevice==0.5.0
\ No newline at end of file funasr
modelscope
ChatTTS
\ No newline at end of file
...@@ -8,6 +8,7 @@ from threading import Event ...@@ -8,6 +8,7 @@ from threading import Event
from typing import Optional from typing import Optional
from sys import platform from sys import platform
from VAD.vad_handler import VADHandler from VAD.vad_handler import VADHandler
from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import ( from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
...@@ -77,6 +78,7 @@ def main(): ...@@ -77,6 +78,7 @@ def main():
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
MeloTTSHandlerArguments, MeloTTSHandlerArguments,
ChatTTSHandlerArguments
) )
) )
...@@ -93,6 +95,7 @@ def main(): ...@@ -93,6 +95,7 @@ def main():
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else: else:
# Parse arguments from command line if no JSON file is provided # Parse arguments from command line if no JSON file is provided
...@@ -106,6 +109,7 @@ def main(): ...@@ -106,6 +109,7 @@ def main():
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
melo_tts_handler_kwargs, melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses() ) = parser.parse_args_into_dataclasses()
# 1. Handle logger # 1. Handle logger
...@@ -178,6 +182,7 @@ def main(): ...@@ -178,6 +182,7 @@ def main():
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo") prepare_args(melo_tts_handler_kwargs, "melo")
prepare_args(chat_tts_handler_kwargs,"chat_tts")
# 3. Build the pipeline # 3. Build the pipeline
stop_event = Event() stop_event = Event()
...@@ -291,6 +296,21 @@ def main(): ...@@ -291,6 +296,21 @@ def main():
setup_args=(should_listen,), setup_args=(should_listen,),
setup_kwargs=vars(melo_tts_handler_kwargs), setup_kwargs=vars(melo_tts_handler_kwargs),
) )
elif module_kwargs.tts == "chatTTS":
try:
from TTS.chatTTS_handler import ChatTTSHandler
except RuntimeError as e:
logger.error(
"Error importing ChatTTSHandler"
)
raise e
tts = ChatTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
setup_args=(should_listen,),
setup_kwargs=vars(chat_tts_handler_kwargs),
)
else: else:
raise ValueError("The TTS should be either parler or melo") raise ValueError("The TTS should be either parler or melo")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment