Unverified commit 8bfbb8df, authored by wuhongsheng, committed by GitHub

Merge branch 'main' into DeepFilterNet

Parents: 363a74de 712afa1c
STT/whisper_stt_handler.py:
@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler):
         # one should warm up with a number of generated tokens above the max targeted for subsequent generation
         # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
         warmup_gen_kwargs = {
-            "min_new_tokens": self.gen_kwargs["max_new_tokens"],  # Yes, assign max_new_tokens to min_new_tokens
+            "min_new_tokens": self.gen_kwargs[
+                "max_new_tokens"
+            ],  # Yes, assign max_new_tokens to min_new_tokens
             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
             **self.gen_kwargs,
         }
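
The order of the keys in that dict literal matters: entries unpacked later win, so anything already present in self.gen_kwargs overrides the explicit entries written above it. A minimal sketch of that merge order, with illustrative values rather than the handler's real configuration:

gen_kwargs = {"max_new_tokens": 128, "language": "en"}
warmup_gen_kwargs = {
    "min_new_tokens": gen_kwargs["max_new_tokens"],  # force full-length generation during warmup
    "max_new_tokens": gen_kwargs["max_new_tokens"],
    **gen_kwargs,  # unpacked last, so duplicate keys here take precedence
}
print(warmup_gen_kwargs)
# {'min_new_tokens': 128, 'max_new_tokens': 128, 'language': 'en'}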

TTS/chatTTS_handler.py (new file):

import ChatTTS
import logging
from baseHandler import BaseHandler
import librosa
import numpy as np
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class ChatTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        device="cuda",
        gen_kwargs={},  # Unused
        stream=True,
        chunk_size=512,
    ):
        self.should_listen = should_listen
        self.device = device
        self.model = ChatTTS.Chat()
        self.model.load(compile=False)  # Doesn't work for me with True
        self.chunk_size = chunk_size
        self.stream = stream
        rnd_spk_emb = self.model.sample_random_speaker()
        self.params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_emb=rnd_spk_emb,
        )
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")
        _ = self.model.infer("text")

    def process(self, llm_sentence):
        console.print(f"[green]ASSISTANT: {llm_sentence}")
        if self.device == "mps":
            import time

            start = time.time()
            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
            torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
            _ = (
                time.time() - start
            )  # Removing this line makes it fail more often. I'm looking into it.

        wavs_gen = self.model.infer(
            llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
        )

        if self.stream:
            wavs = [np.array([])]
            for gen in wavs_gen:
                if gen[0] is None or len(gen[0]) == 0:
                    self.should_listen.set()
                    return
                audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
                # Scale float [-1, 1] audio to int16 PCM; streamed arrays carry a leading batch dimension.
                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
                while len(audio_chunk) > self.chunk_size:
                    yield audio_chunk[: self.chunk_size]  # return the first chunk_size samples
                    audio_chunk = audio_chunk[self.chunk_size :]  # drop the samples already returned
                yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
        else:
            wavs = wavs_gen
            if len(wavs[0]) == 0:
                self.should_listen.set()
                return
            audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            for i in range(0, len(audio_chunk), self.chunk_size):
                yield np.pad(
                    audio_chunk[i : i + self.chunk_size],
                    (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
                )
        self.should_listen.set()
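
Both branches of process emit fixed-size int16 chunks and zero-pad the final partial chunk, so the consumer can rely on a constant frame length. A standalone sketch of that slicing pattern, using a hypothetical chunk_audio helper that is not part of the handler:

import numpy as np

def chunk_audio(samples, chunk_size=512):
    # Slice a 1-D sample buffer into equal-sized chunks, zero-padding the last one.
    for i in range(0, len(samples), chunk_size):
        chunk = samples[i : i + chunk_size]
        yield np.pad(chunk, (0, chunk_size - len(chunk)))

chunks = list(chunk_audio(np.arange(1200, dtype=np.int16)))
print(len(chunks), len(chunks[-1]))  # 3 512 -> every chunk has exactly 512 samples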

arguments_classes/chat_tts_arguments.py (new file):

from dataclasses import dataclass, field


@dataclass
class ChatTTSHandlerArguments:
    chat_tts_stream: bool = field(
        default=True,
        metadata={"help": "Whether to stream the TTS output. Default is True."},
    )
    chat_tts_device: str = field(
        default="cuda",
        metadata={
            "help": "The device to be used for speech synthesis. Default is 'cuda'."
        },
    )
    chat_tts_chunk_size: int = field(
        default=512,
        metadata={
            "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load. Default is 512."
        },
    )
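
Each field name carries the chat_tts_ prefix, which the pipeline's prepare_args step presumably strips before the values reach ChatTTSHandler.setup as stream, device, and chunk_size. A minimal sketch of that renaming, under the assumption that prepare_args works this way (strip_prefix is a hypothetical stand-in):

def strip_prefix(kwargs, prefix):
    # e.g. {"chat_tts_chunk_size": 512} -> {"chunk_size": 512}
    return {
        key[len(prefix) + 1 :]: value
        for key, value in kwargs.items()
        if key.startswith(prefix + "_")
    }

print(strip_prefix({"chat_tts_stream": True, "chat_tts_chunk_size": 512}, "chat_tts"))
# {'stream': True, 'chunk_size': 512}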
@@ -35,7 +35,7 @@ class ModuleArguments:
     tts: Optional[str] = field(
         default="parler",
         metadata={
-            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
+            "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
         },
     )
     log_level: str = field(
requirements.txt:
@@ -3,6 +3,7 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
 melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS  # made a copy of MeloTTS to have compatible versions of transformers
 torch==2.4.0
 sounddevice==0.5.0
-funasr
-modelscope
-deepfilternet
+ChatTTS>=0.1.1
+funasr>=1.1.6
+modelscope>=1.17.1
+deepfilternet>=0.5.6
requirements_mac.txt:
@@ -5,6 +5,8 @@ torch==2.4.0
 sounddevice==0.5.0
 lightning-whisper-mlx>=0.0.10
 mlx-lm>=0.14.0
+ChatTTS>=0.1.1
+funasr>=1.1.6
+modelscope>=1.17.1
-deepfilternet
+deepfilternet>=0.5.6
s2s_pipeline.py:
@@ -8,6 +8,7 @@ from threading import Event
 from typing import Optional
 from sys import platform
 from VAD.vad_handler import VADHandler
+from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
 from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
 from arguments_classes.mlx_language_model_arguments import (
     MLXLanguageModelHandlerArguments,
@@ -79,6 +80,7 @@ def main():
             MLXLanguageModelHandlerArguments,
             ParlerTTSHandlerArguments,
             MeloTTSHandlerArguments,
+            ChatTTSHandlerArguments,
         )
     )
@@ -96,6 +98,7 @@
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
+            chat_tts_handler_kwargs,
         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         # Parse arguments from command line if no JSON file is provided
@@ -110,6 +113,7 @@
             mlx_language_model_handler_kwargs,
             parler_tts_handler_kwargs,
             melo_tts_handler_kwargs,
+            chat_tts_handler_kwargs,
         ) = parser.parse_args_into_dataclasses()

     # 1. Handle logger
@@ -186,6 +190,7 @@
     prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
     prepare_args(parler_tts_handler_kwargs, "tts")
     prepare_args(melo_tts_handler_kwargs, "melo")
+    prepare_args(chat_tts_handler_kwargs, "chat_tts")

     # 3. Build the pipeline
     stop_event = Event()
@@ -310,8 +315,21 @@
             setup_args=(should_listen,),
             setup_kwargs=vars(melo_tts_handler_kwargs),
         )
+    elif module_kwargs.tts == "chatTTS":
+        try:
+            from TTS.chatTTS_handler import ChatTTSHandler
+        except RuntimeError as e:
+            logger.error("Error importing ChatTTSHandler")
+            raise e
+        tts = ChatTTSHandler(
+            stop_event,
+            queue_in=lm_response_queue,
+            queue_out=send_audio_chunks_queue,
+            setup_args=(should_listen,),
+            setup_kwargs=vars(chat_tts_handler_kwargs),
+        )
     else:
-        raise ValueError("The TTS should be either parler or melo")
+        raise ValueError("The TTS should be either parler, melo or chatTTS")

     # 4. Run the pipeline
     try:
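
With this branch merged, ChatTTS can be selected like any other TTS backend. A usage sketch (the script name s2s_pipeline.py and the exact flag spellings are assumptions derived from the dataclasses above):

python s2s_pipeline.py --tts chatTTS --chat_tts_device cuda --chat_tts_chunk_size 512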