Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Speech To Speech
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
huggingface
Speech To Speech
Commits
9941f6f3
Commit
9941f6f3
authored
6 months ago
by
Andres Marafioti
Browse files
Options
Downloads
Patches
Plain Diff
refactor s2s_pipeline
parent
0ae1b01d
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
s2s_pipeline.py
+142
-96
142 additions, 96 deletions
s2s_pipeline.py
with
142 additions
and
96 deletions
s2s_pipeline.py
+
142
−
96
View file @
9941f6f3
...
...
@@ -67,7 +67,7 @@ def prepare_args(args, prefix):
args
.
__dict__
[
"
gen_kwargs
"
]
=
gen_kwargs
def
main
():
def
parse_arguments
():
parser
=
HfArgumentParser
(
(
ModuleArguments
,
...
...
@@ -84,69 +84,40 @@ def main():
)
)
# 0. Parse CLI arguments
if
len
(
sys
.
argv
)
==
2
and
sys
.
argv
[
1
].
endswith
(
"
.json
"
):
# Parse configurations from a JSON file if specified
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
=
parser
.
parse_json_file
(
json_file
=
os
.
path
.
abspath
(
sys
.
argv
[
1
]))
return
parser
.
parse_json_file
(
json_file
=
os
.
path
.
abspath
(
sys
.
argv
[
1
]))
else
:
# Parse arguments from command line if no JSON file is provided
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
=
parser
.
parse_args_into_dataclasses
()
# 1. Handle logger
return
parser
.
parse_args_into_dataclasses
()
def
setup_logger
(
log_level
):
global
logger
logging
.
basicConfig
(
level
=
module_kwargs
.
log_level
.
upper
(),
level
=
log_level
.
upper
(),
format
=
"
%(asctime)s - %(name)s - %(levelname)s - %(message)s
"
,
)
logger
=
logging
.
getLogger
(
__name__
)
# torch compile logs
if
module_kwargs
.
log_level
==
"
debug
"
:
if
log_level
==
"
debug
"
:
torch
.
_logging
.
set_logs
(
graph_breaks
=
True
,
recompiles
=
True
,
cudagraphs
=
True
)
def
optimal_mac_settings
(
mac_optimal_settings
:
Optional
[
str
],
*
handler_kwargs
):
if
mac_optimal_settings
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
device
"
):
kwargs
.
device
=
"
mps
"
if
hasattr
(
kwargs
,
"
mode
"
):
kwargs
.
mode
=
"
local
"
if
hasattr
(
kwargs
,
"
stt
"
):
kwargs
.
stt
=
"
whisper-mlx
"
if
hasattr
(
kwargs
,
"
llm
"
):
kwargs
.
llm
=
"
mlx-lm
"
if
hasattr
(
kwargs
,
"
tts
"
):
kwargs
.
tts
=
"
melo
"
optimal_mac_settings
(
module_kwargs
.
local_mac_optimal_settings
,
module_kwargs
,
)
def
optimal_mac_settings
(
mac_optimal_settings
:
Optional
[
str
],
*
handler_kwargs
):
if
mac_optimal_settings
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
device
"
):
kwargs
.
device
=
"
mps
"
if
hasattr
(
kwargs
,
"
mode
"
):
kwargs
.
mode
=
"
local
"
if
hasattr
(
kwargs
,
"
stt
"
):
kwargs
.
stt
=
"
whisper-mlx
"
if
hasattr
(
kwargs
,
"
llm
"
):
kwargs
.
llm
=
"
mlx-lm
"
if
hasattr
(
kwargs
,
"
tts
"
):
kwargs
.
tts
=
"
melo
"
def
check_mac_settings
(
module_kwargs
):
if
platform
==
"
darwin
"
:
if
module_kwargs
.
device
==
"
cuda
"
:
raise
ValueError
(
...
...
@@ -161,29 +132,29 @@ def main():
"
If you experiences issues generating the voice, considering setting the tts to melo.
"
)
# 2. Prepare each part's arguments
def
overwrite_device_argument
(
common_device
:
Optional
[
str
],
*
handler_kwargs
):
if
common_device
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
lm_device
"
):
kwargs
.
lm_device
=
common_device
if
hasattr
(
kwargs
,
"
tts_device
"
):
kwargs
.
tts_device
=
common_device
if
hasattr
(
kwargs
,
"
stt_device
"
):
kwargs
.
stt_device
=
common_device
if
hasattr
(
kwargs
,
"
paraformer_stt_device
"
):
kwargs
.
paraformer_stt_device
=
common_device
# Call this function with the common device and all the handlers
overwrite_device_argument
(
module_kwargs
.
device
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
)
def
overwrite_device_argument
(
common_device
:
Optional
[
str
],
*
handler_kwargs
):
if
common_device
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
lm_device
"
):
kwargs
.
lm_device
=
common_device
if
hasattr
(
kwargs
,
"
tts_device
"
):
kwargs
.
tts_device
=
common_device
if
hasattr
(
kwargs
,
"
stt_device
"
):
kwargs
.
stt_device
=
common_device
if
hasattr
(
kwargs
,
"
paraformer_stt_device
"
):
kwargs
.
paraformer_stt_device
=
common_device
def
prepare_all_args
(
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
):
prepare_args
(
whisper_stt_handler_kwargs
,
"
stt
"
)
prepare_args
(
paraformer_stt_handler_kwargs
,
"
paraformer_stt
"
)
prepare_args
(
language_model_handler_kwargs
,
"
lm
"
)
...
...
@@ -192,7 +163,20 @@ def main():
prepare_args
(
melo_tts_handler_kwargs
,
"
melo
"
)
prepare_args
(
chat_tts_handler_kwargs
,
"
chat_tts
"
)
# 3. Build the pipeline
def
build_pipeline
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
):
stop_event
=
Event
()
# used to stop putting received audio chunks in queue until all setences have been processed by the TTS
should_listen
=
Event
()
...
...
@@ -238,10 +222,18 @@ def main():
setup_args
=
(
should_listen
,),
setup_kwargs
=
vars
(
vad_handler_kwargs
),
)
stt
=
get_stt_handler
(
module_kwargs
,
stop_event
,
spoken_prompt_queue
,
text_prompt_queue
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
)
lm
=
get_llm_handler
(
module_kwargs
,
stop_event
,
text_prompt_queue
,
lm_response_queue
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
)
tts
=
get_tts_handler
(
module_kwargs
,
stop_event
,
lm_response_queue
,
send_audio_chunks_queue
,
should_listen
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
)
return
ThreadManager
([
*
comms_handlers
,
vad
,
stt
,
lm
,
tts
])
def
get_stt_handler
(
module_kwargs
,
stop_event
,
spoken_prompt_queue
,
text_prompt_queue
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
):
if
module_kwargs
.
stt
==
"
whisper
"
:
from
STT.whisper_stt_handler
import
WhisperSTTHandler
stt
=
WhisperSTTHandler
(
return
WhisperSTTHandler
(
stop_event
,
queue_in
=
spoken_prompt_queue
,
queue_out
=
text_prompt_queue
,
...
...
@@ -249,8 +241,7 @@ def main():
)
elif
module_kwargs
.
stt
==
"
whisper-mlx
"
:
from
STT.lightning_whisper_mlx_handler
import
LightningWhisperSTTHandler
stt
=
LightningWhisperSTTHandler
(
return
LightningWhisperSTTHandler
(
stop_event
,
queue_in
=
spoken_prompt_queue
,
queue_out
=
text_prompt_queue
,
...
...
@@ -258,21 +249,20 @@ def main():
)
elif
module_kwargs
.
stt
==
"
paraformer
"
:
from
STT.paraformer_handler
import
ParaformerSTTHandler
stt
=
ParaformerSTTHandler
(
return
ParaformerSTTHandler
(
stop_event
,
queue_in
=
spoken_prompt_queue
,
queue_out
=
text_prompt_queue
,
setup_kwargs
=
vars
(
paraformer_stt_handler_kwargs
),
)
else
:
raise
ValueError
(
"
The STT should be either whisper, whisper-mlx, or paraformer.
"
)
raise
ValueError
(
"
The STT should be either whisper, whisper-mlx, or paraformer.
"
)
def
get_llm_handler
(
module_kwargs
,
stop_event
,
text_prompt_queue
,
lm_response_queue
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
):
if
module_kwargs
.
llm
==
"
transformers
"
:
from
LLM.language_model
import
LanguageModelHandler
lm
=
LanguageModelHandler
(
return
LanguageModelHandler
(
stop_event
,
queue_in
=
text_prompt_queue
,
queue_out
=
lm_response_queue
,
...
...
@@ -280,8 +270,7 @@ def main():
)
elif
module_kwargs
.
llm
==
"
mlx-lm
"
:
from
LLM.mlx_language_model
import
MLXLanguageModelHandler
lm
=
MLXLanguageModelHandler
(
return
MLXLanguageModelHandler
(
stop_event
,
queue_in
=
text_prompt_queue
,
queue_out
=
lm_response_queue
,
...
...
@@ -289,10 +278,12 @@ def main():
)
else
:
raise
ValueError
(
"
The LLM should be either transformers or mlx-lm
"
)
def
get_tts_handler
(
module_kwargs
,
stop_event
,
lm_response_queue
,
send_audio_chunks_queue
,
should_listen
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
):
if
module_kwargs
.
tts
==
"
parler
"
:
from
TTS.parler_handler
import
ParlerTTSHandler
tts
=
ParlerTTSHandler
(
return
ParlerTTSHandler
(
stop_event
,
queue_in
=
lm_response_queue
,
queue_out
=
send_audio_chunks_queue
,
...
...
@@ -307,7 +298,7 @@ def main():
"
Error importing MeloTTSHandler. You might need to run: python -m unidic download
"
)
raise
e
tts
=
MeloTTSHandler
(
return
MeloTTSHandler
(
stop_event
,
queue_in
=
lm_response_queue
,
queue_out
=
send_audio_chunks_queue
,
...
...
@@ -320,7 +311,7 @@ def main():
except
RuntimeError
as
e
:
logger
.
error
(
"
Error importing ChatTTSHandler
"
)
raise
e
tts
=
ChatTTSHandler
(
return
ChatTTSHandler
(
stop_event
,
queue_in
=
lm_response_queue
,
queue_out
=
send_audio_chunks_queue
,
...
...
@@ -330,14 +321,69 @@ def main():
else
:
raise
ValueError
(
"
The TTS should be either parler, melo or chatTTS
"
)
# 4. Run the pipeline
def
main
():
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
=
parse_arguments
()
setup_logger
(
module_kwargs
.
log_level
)
optimal_mac_settings
(
module_kwargs
.
local_mac_optimal_settings
,
module_kwargs
,
)
check_mac_settings
(
module_kwargs
)
overwrite_device_argument
(
module_kwargs
.
device
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
)
prepare_all_args
(
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
pipeline_manager
=
build_pipeline
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
try
:
pipeline_manager
=
ThreadManager
([
*
comms_handlers
,
vad
,
stt
,
lm
,
tts
])
pipeline_manager
.
start
()
except
KeyboardInterrupt
:
pipeline_manager
.
stop
()
if
__name__
==
"
__main__
"
:
main
()
main
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment