Skip to content
Snippets Groups Projects
Commit 07ea8e75 authored by Andres Marafioti's avatar Andres Marafioti
Browse files

A few changes to improve the Paraformer addition

parent ef9ce5b3
No related branches found
No related tags found
No related merge requests found
import logging import logging
from time import perf_counter from time import perf_counter
from tensorstore import dtype
from baseHandler import BaseHandler from baseHandler import BaseHandler
from funasr import AutoModel from funasr import AutoModel
import numpy as np import numpy as np
...@@ -19,22 +17,22 @@ console = Console() ...@@ -19,22 +17,22 @@ console = Console()
class ParaformerSTTHandler(BaseHandler): class ParaformerSTTHandler(BaseHandler):
""" """
Handles the Speech To Text generation using a Whisper model. Handles the Speech To Text generation using a Paraformer model.
The default for this model is set to Chinese.
This model was contributed by @wuhongsheng.
""" """
def setup( def setup(
self, self,
model_name="paraformer-zh", model_name="paraformer-zh",
device="cuda", device="cuda",
torch_dtype="float32",
compile_mode=None,
gen_kwargs={}, gen_kwargs={},
): ):
print(model_name) print(model_name)
if len(model_name.split("/")) > 1: if len(model_name.split("/")) > 1:
model_name = model_name.split("/")[-1] model_name = model_name.split("/")[-1]
self.device = device self.device = device
self.model = AutoModel(model=model_name) self.model = AutoModel(model=model_name, device=device)
self.warmup() self.warmup()
def warmup(self): def warmup(self):
...@@ -42,9 +40,9 @@ class ParaformerSTTHandler(BaseHandler): ...@@ -42,9 +40,9 @@ class ParaformerSTTHandler(BaseHandler):
# 2 warmup steps for no compile or compile mode with CUDA graphs capture # 2 warmup steps for no compile or compile mode with CUDA graphs capture
n_steps = 1 n_steps = 1
dummy_input = np.array([0] * 512,dtype=np.float32) dummy_input = np.array([0] * 512, dtype=np.float32)
for _ in range(n_steps): for _ in range(n_steps):
_ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ","") _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")
def process(self, spoken_prompt): def process(self, spoken_prompt):
logger.debug("infering paraformer...") logger.debug("infering paraformer...")
...@@ -52,7 +50,9 @@ class ParaformerSTTHandler(BaseHandler): ...@@ -52,7 +50,9 @@ class ParaformerSTTHandler(BaseHandler):
global pipeline_start global pipeline_start
pipeline_start = perf_counter() pipeline_start = perf_counter()
pred_text = self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ","") pred_text = (
self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
)
torch.mps.empty_cache() torch.mps.empty_cache()
logger.debug("finished paraformer inference") logger.debug("finished paraformer inference")
......
...@@ -23,7 +23,7 @@ class ModuleArguments: ...@@ -23,7 +23,7 @@ class ModuleArguments:
stt: Optional[str] = field( stt: Optional[str] = field(
default="whisper", default="whisper",
metadata={ metadata={
"help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'." "help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'."
}, },
) )
llm: Optional[str] = field( llm: Optional[str] = field(
......
from dataclasses import dataclass, field
@dataclass
class ParaformerSTTHandlerArguments:
paraformer_stt_model_name: str = field(
default="paraformer-zh",
metadata={
"help": "The pretrained model to use. Default is 'paraformer-zh'. Can be choose from https://github.com/modelscope/FunASR"
},
)
paraformer_stt_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
...@@ -5,5 +5,5 @@ torch==2.4.0 ...@@ -5,5 +5,5 @@ torch==2.4.0
sounddevice==0.5.0 sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10 lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0 mlx-lm>=0.14.0
funasr funasr>=1.1.6
modelscope modelscope>=1.17.1
\ No newline at end of file \ No newline at end of file
...@@ -13,6 +13,7 @@ from arguments_classes.mlx_language_model_arguments import ( ...@@ -13,6 +13,7 @@ from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
) )
from arguments_classes.module_arguments import ModuleArguments from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
from arguments_classes.socket_sender_arguments import SocketSenderArguments from arguments_classes.socket_sender_arguments import SocketSenderArguments
...@@ -73,6 +74,7 @@ def main(): ...@@ -73,6 +74,7 @@ def main():
SocketSenderArguments, SocketSenderArguments,
VADHandlerArguments, VADHandlerArguments,
WhisperSTTHandlerArguments, WhisperSTTHandlerArguments,
ParaformerSTTHandlerArguments,
LanguageModelHandlerArguments, LanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments, ParlerTTSHandlerArguments,
...@@ -89,6 +91,7 @@ def main(): ...@@ -89,6 +91,7 @@ def main():
socket_sender_kwargs, socket_sender_kwargs,
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
...@@ -102,6 +105,7 @@ def main(): ...@@ -102,6 +105,7 @@ def main():
socket_sender_kwargs, socket_sender_kwargs,
vad_handler_kwargs, vad_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs, language_model_handler_kwargs,
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
...@@ -163,6 +167,8 @@ def main(): ...@@ -163,6 +167,8 @@ def main():
kwargs.tts_device = common_device kwargs.tts_device = common_device
if hasattr(kwargs, "stt_device"): if hasattr(kwargs, "stt_device"):
kwargs.stt_device = common_device kwargs.stt_device = common_device
if hasattr(kwargs, "paraformer_stt_device"):
kwargs.paraformer_stt_device = common_device
# Call this function with the common device and all the handlers # Call this function with the common device and all the handlers
overwrite_device_argument( overwrite_device_argument(
...@@ -171,9 +177,11 @@ def main(): ...@@ -171,9 +177,11 @@ def main():
mlx_language_model_handler_kwargs, mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs, parler_tts_handler_kwargs,
whisper_stt_handler_kwargs, whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
) )
prepare_args(whisper_stt_handler_kwargs, "stt") prepare_args(whisper_stt_handler_kwargs, "stt")
prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
prepare_args(language_model_handler_kwargs, "lm") prepare_args(language_model_handler_kwargs, "lm")
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm") prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts") prepare_args(parler_tts_handler_kwargs, "tts")
...@@ -245,14 +253,17 @@ def main(): ...@@ -245,14 +253,17 @@ def main():
) )
elif module_kwargs.stt == "paraformer": elif module_kwargs.stt == "paraformer":
from STT.paraformer_handler import ParaformerSTTHandler from STT.paraformer_handler import ParaformerSTTHandler
stt = ParaformerSTTHandler( stt = ParaformerSTTHandler(
stop_event, stop_event,
queue_in=spoken_prompt_queue, queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue, queue_out=text_prompt_queue,
# setup_kwargs=vars(whisper_stt_handler_kwargs), setup_kwargs=vars(paraformer_stt_handler_kwargs),
) )
else: else:
raise ValueError("The STT should be either whisper or whisper-mlx") raise ValueError(
"The STT should be either whisper, whisper-mlx, or paraformer."
)
if module_kwargs.llm == "transformers": if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler from LLM.language_model import LanguageModelHandler
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment