Skip to content
Snippets Groups Projects
Unverified Commit 7c99fd7f authored by Andrés Marafioti's avatar Andrés Marafioti Committed by GitHub
Browse files

Merge branch 'add-chattts' into chatTTS3

parents 54194b35 7978683c
Branches
Tags
No related merge requests found
import logging
from time import perf_counter
from baseHandler import BaseHandler
from funasr import AutoModel
import numpy as np
from rich.console import Console
import torch
# Configure the root logging format once at import time; the log level is
# left to the application (presumably set by the pipeline entry point — confirm).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)  # module-scoped logger, per stdlib convention
console = Console()  # rich console used for the user-facing transcript line
class ParaformerSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Paraformer model.

    The default for this model is set to Chinese ("paraformer-zh").
    This model was contributed by @wuhongsheng.
    """

    def setup(
        self,
        model_name="paraformer-zh",
        device="cuda",
        gen_kwargs=None,  # kept for pipeline interface parity; not used by this handler
    ):
        """Load the FunASR AutoModel on `device` and run a warmup pass.

        `model_name` may be a hub-style path ("org/name"); only the trailing
        component is handed to FunASR.
        """
        # Use the logger instead of a stray debug print().
        logger.info("Loading Paraformer model: %s", model_name)
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = AutoModel(model=model_name, device=device)
        self.warmup()

    def warmup(self):
        """Run one dummy inference so the first real request is not slowed by lazy init."""
        logger.info(f"Warming up {self.__class__.__name__}")
        # A single warmup step suffices here: this handler does not use
        # torch.compile / CUDA graph capture (which would need two steps).
        n_steps = 1
        dummy_input = np.array([0] * 512, dtype=np.float32)
        for _ in range(n_steps):
            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")

    def process(self, spoken_prompt):
        """Transcribe `spoken_prompt` (float32 audio array) and yield the text."""
        logger.debug("infering paraformer...")

        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = (
            self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
        )
        # Only release MPS cached memory when actually running on Apple MPS;
        # calling torch.mps.empty_cache() unconditionally can fail on the
        # default "cuda" (or cpu) configurations.
        if self.device == "mps":
            torch.mps.empty_cache()

        logger.debug("finished paraformer inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
......@@ -66,8 +66,9 @@ class WhisperSTTHandler(BaseHandler):
if self.compile_mode not in (None, "default"):
# generating more tokens than previously will trigger CUDA graphs capture
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
# hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["min_new_tokens"],
"min_new_tokens": self.gen_kwargs["max_new_tokens"], # Yes, assign max_new_tokens to min_new_tokens
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
......
......@@ -23,7 +23,7 @@ class ModuleArguments:
stt: Optional[str] = field(
default="whisper",
metadata={
"help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
"help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'."
},
)
llm: Optional[str] = field(
......
from dataclasses import dataclass, field
@dataclass
class ParaformerSTTHandlerArguments:
    """CLI arguments for the Paraformer speech-to-text handler."""

    # FunASR pretrained model identifier (see https://github.com/modelscope/FunASR).
    paraformer_stt_model_name: str = field(
        default="paraformer-zh",
        metadata={
            # Fixed grammar: "Can be choose from" -> "Models can be chosen from".
            "help": "The pretrained model to use. Default is 'paraformer-zh'. Models can be chosen from https://github.com/modelscope/FunASR"
        },
    )
    # Device string passed through to FunASR's AutoModel (e.g. "cuda", "cpu", "mps").
    paraformer_stt_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        },
    )
......@@ -33,12 +33,6 @@ class WhisperSTTHandlerArguments:
"help": "The maximum number of new tokens to generate. Default is 128."
},
)
stt_gen_min_new_tokens: int = field(
default=0,
metadata={
"help": "The minimum number of new tokens to generate. Default is 0."
},
)
stt_gen_num_beams: int = field(
default=1,
metadata={
......
......@@ -3,4 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0
sounddevice==0.5.0
ChatTTS
\ No newline at end of file
ChatTTS
funasr
modelscope
\ No newline at end of file
......@@ -5,4 +5,6 @@ torch==2.4.0
sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0
ChatTTS
\ No newline at end of file
ChatTTS
funasr>=1.1.6
modelscope>=1.17.1
......@@ -14,6 +14,7 @@ from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments,
)
from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
from arguments_classes.socket_sender_arguments import SocketSenderArguments
......@@ -74,6 +75,7 @@ def main():
SocketSenderArguments,
VADHandlerArguments,
WhisperSTTHandlerArguments,
ParaformerSTTHandlerArguments,
LanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments,
......@@ -91,6 +93,7 @@ def main():
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
......@@ -105,6 +108,7 @@ def main():
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
......@@ -167,6 +171,8 @@ def main():
kwargs.tts_device = common_device
if hasattr(kwargs, "stt_device"):
kwargs.stt_device = common_device
if hasattr(kwargs, "paraformer_stt_device"):
kwargs.paraformer_stt_device = common_device
# Call this function with the common device and all the handlers
overwrite_device_argument(
......@@ -175,9 +181,11 @@ def main():
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
)
prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
prepare_args(language_model_handler_kwargs, "lm")
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts")
......@@ -248,8 +256,19 @@ def main():
queue_out=text_prompt_queue,
setup_kwargs=vars(whisper_stt_handler_kwargs),
)
elif module_kwargs.stt == "paraformer":
from STT.paraformer_handler import ParaformerSTTHandler
stt = ParaformerSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
setup_kwargs=vars(paraformer_stt_handler_kwargs),
)
else:
raise ValueError("The STT should be either whisper or whisper-mlx")
raise ValueError(
"The STT should be either whisper, whisper-mlx, or paraformer."
)
if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment