Skip to content
Snippets Groups Projects
Unverified Commit c40bf05d authored by Andrés Marafioti's avatar Andrés Marafioti Committed by GitHub
Browse files

Merge pull request #55 from wuhongsheng/chatTTS3

feat:add chatTTS
parents 7978683c 7c99fd7f
No related branches found
No related tags found
No related merge requests found
import ChatTTS
import logging
from baseHandler import BaseHandler
import librosa
import numpy as np
from rich.console import Console
import torch
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class ChatTTSHandler(BaseHandler):
def setup(
self,
should_listen,
device="mps",
gen_kwargs={}, # Unused
stream=True,
chunk_size=512,
):
self.should_listen = should_listen
self.device = device
self.model = ChatTTS.Chat()
self.model.load(compile=True) # Set to True for better performance
self.chunk_size = chunk_size
self.stream = stream
rnd_spk_emb = self.model.sample_random_speaker()
self.params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb=rnd_spk_emb,
)
self.warmup()
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
_= self.model.infer("text")
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
if self.device == "mps":
import time
start = time.time()
torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
_ = (
time.time() - start
) # Removing this line makes it fail more often. I'm looking into it.
wavs_gen = self.model.infer(llm_sentence,params_infer_code=self.params_infer_code, stream=self.stream)
if self.stream:
wavs = [np.array([])]
for gen in wavs_gen:
print('new chunk gen', len(gen[0]))
if len(gen[0]) == 0:
self.should_listen.set()
return
audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
print('audio_chunk:', audio_chunk.shape)
while len(audio_chunk) > self.chunk_size:
yield audio_chunk[:self.chunk_size] # 返回前 chunk_size 字节的数据
audio_chunk = audio_chunk[self.chunk_size:] # 移除已返回的数据
yield np.pad(audio_chunk, (0,self.chunk_size-len(audio_chunk)))
else:
print('check result', wavs_gen)
wavs = wavs_gen
if len(wavs[0]) == 0:
self.should_listen.set()
return
audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
print('audio_chunk:', audio_chunk.shape)
for i in range(0, len(audio_chunk), self.chunk_size):
yield np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
self.should_listen.set()
from dataclasses import dataclass, field
@dataclass
class ChatTTSHandlerArguments:
chat_tts_stream: bool = field(
default=True,
metadata={
"help": "The tts mode is stream Default is 'stream'."
},
)
chat_tts_device: str = field(
default="mps",
metadata={
"help": "The device to be used for speech synthesis. Default is 'mps'."
},
)
chat_tts_chunk_size: int = field(
default=512,
metadata={
"help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load.. Default is 512。."
},
)
......@@ -3,5 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0
sounddevice==0.5.0
ChatTTS
funasr
modelscope
\ No newline at end of file
......@@ -5,5 +5,6 @@ torch==2.4.0
sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0
ChatTTS
funasr>=1.1.6
modelscope>=1.17.1
\ No newline at end of file
modelscope>=1.17.1
......@@ -8,6 +8,7 @@ from threading import Event
from typing import Optional
from sys import platform
from VAD.vad_handler import VADHandler
from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments,
......@@ -79,6 +80,7 @@ def main():
MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments,
MeloTTSHandlerArguments,
ChatTTSHandlerArguments
)
)
......@@ -96,6 +98,7 @@ def main():
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
# Parse arguments from command line if no JSON file is provided
......@@ -110,6 +113,7 @@ def main():
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parser.parse_args_into_dataclasses()
# 1. Handle logger
......@@ -186,6 +190,7 @@ def main():
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo")
prepare_args(chat_tts_handler_kwargs,"chat_tts")
# 3. Build the pipeline
stop_event = Event()
......@@ -310,6 +315,21 @@ def main():
setup_args=(should_listen,),
setup_kwargs=vars(melo_tts_handler_kwargs),
)
elif module_kwargs.tts == "chatTTS":
try:
from TTS.chatTTS_handler import ChatTTSHandler
except RuntimeError as e:
logger.error(
"Error importing ChatTTSHandler"
)
raise e
tts = ChatTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
setup_args=(should_listen,),
setup_kwargs=vars(chat_tts_handler_kwargs),
)
else:
raise ValueError("The TTS should be either parler or melo")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment