Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Speech To Speech
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
huggingface
Speech To Speech
Commits
8c7272b7
Unverified
Commit
8c7272b7
authored
6 months ago
by
eustlb
Committed by
GitHub
6 months ago
Browse files
Options
Downloads
Plain Diff
Merge pull request #106 from huggingface/refactor_for_inference
Refactor for inference
parents
0ae1b01d
0bc30b68
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
s2s_pipeline.py
+175
-108
175 additions, 108 deletions
s2s_pipeline.py
with
175 additions
and
108 deletions
s2s_pipeline.py
+
175
−
108
View file @
8c7272b7
...
@@ -49,11 +49,10 @@ console = Console()
...
@@ -49,11 +49,10 @@ console = Console()
logging
.
getLogger
(
"
numba
"
).
setLevel
(
logging
.
WARNING
)
# quiet down numba logs
logging
.
getLogger
(
"
numba
"
).
setLevel
(
logging
.
WARNING
)
# quiet down numba logs
def
p
re
par
e_args
(
args
,
prefix
):
def
re
nam
e_args
(
args
,
prefix
):
"""
"""
Rename arguments by removing the prefix and prepares the gen_kwargs.
Rename arguments by removing the prefix and prepares the gen_kwargs.
"""
"""
gen_kwargs
=
{}
gen_kwargs
=
{}
for
key
in
copy
(
args
.
__dict__
):
for
key
in
copy
(
args
.
__dict__
):
if
key
.
startswith
(
prefix
):
if
key
.
startswith
(
prefix
):
...
@@ -67,7 +66,7 @@ def prepare_args(args, prefix):
...
@@ -67,7 +66,7 @@ def prepare_args(args, prefix):
args
.
__dict__
[
"
gen_kwargs
"
]
=
gen_kwargs
args
.
__dict__
[
"
gen_kwargs
"
]
=
gen_kwargs
def
main
():
def
parse_arguments
():
parser
=
HfArgumentParser
(
parser
=
HfArgumentParser
(
(
(
ModuleArguments
,
ModuleArguments
,
...
@@ -84,69 +83,43 @@ def main():
...
@@ -84,69 +83,43 @@ def main():
)
)
)
)
# 0. Parse CLI arguments
if
len
(
sys
.
argv
)
==
2
and
sys
.
argv
[
1
].
endswith
(
"
.json
"
):
if
len
(
sys
.
argv
)
==
2
and
sys
.
argv
[
1
].
endswith
(
"
.json
"
):
# Parse configurations from a JSON file if specified
# Parse configurations from a JSON file if specified
(
return
parser
.
parse_json_file
(
json_file
=
os
.
path
.
abspath
(
sys
.
argv
[
1
]))
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
=
parser
.
parse_json_file
(
json_file
=
os
.
path
.
abspath
(
sys
.
argv
[
1
]))
else
:
else
:
# Parse arguments from command line if no JSON file is provided
# Parse arguments from command line if no JSON file is provided
(
return
parser
.
parse_args_into_dataclasses
()
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
def
setup_logger
(
log_level
):
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
=
parser
.
parse_args_into_dataclasses
()
# 1. Handle logger
global
logger
global
logger
logging
.
basicConfig
(
logging
.
basicConfig
(
level
=
module_kwargs
.
log_level
.
upper
(),
level
=
log_level
.
upper
(),
format
=
"
%(asctime)s - %(name)s - %(levelname)s - %(message)s
"
,
format
=
"
%(asctime)s - %(name)s - %(levelname)s - %(message)s
"
,
)
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# torch compile logs
# torch compile logs
if
module_kwargs
.
log_level
==
"
debug
"
:
if
log_level
==
"
debug
"
:
torch
.
_logging
.
set_logs
(
graph_breaks
=
True
,
recompiles
=
True
,
cudagraphs
=
True
)
torch
.
_logging
.
set_logs
(
graph_breaks
=
True
,
recompiles
=
True
,
cudagraphs
=
True
)
def
optimal_mac_settings
(
mac_optimal_settings
:
Optional
[
str
],
*
handler_kwargs
):
if
mac_optimal_settings
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
device
"
):
kwargs
.
device
=
"
mps
"
if
hasattr
(
kwargs
,
"
mode
"
):
kwargs
.
mode
=
"
local
"
if
hasattr
(
kwargs
,
"
stt
"
):
kwargs
.
stt
=
"
whisper-mlx
"
if
hasattr
(
kwargs
,
"
llm
"
):
kwargs
.
llm
=
"
mlx-lm
"
if
hasattr
(
kwargs
,
"
tts
"
):
kwargs
.
tts
=
"
melo
"
optimal_mac_settings
(
module_kwargs
.
local_mac_optimal_settings
,
module_kwargs
,
)
def
optimal_mac_settings
(
mac_optimal_settings
:
Optional
[
str
],
*
handler_kwargs
):
if
mac_optimal_settings
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
device
"
):
kwargs
.
device
=
"
mps
"
if
hasattr
(
kwargs
,
"
mode
"
):
kwargs
.
mode
=
"
local
"
if
hasattr
(
kwargs
,
"
stt
"
):
kwargs
.
stt
=
"
whisper-mlx
"
if
hasattr
(
kwargs
,
"
llm
"
):
kwargs
.
llm
=
"
mlx-lm
"
if
hasattr
(
kwargs
,
"
tts
"
):
kwargs
.
tts
=
"
melo
"
def
check_mac_settings
(
module_kwargs
):
if
platform
==
"
darwin
"
:
if
platform
==
"
darwin
"
:
if
module_kwargs
.
device
==
"
cuda
"
:
if
module_kwargs
.
device
==
"
cuda
"
:
raise
ValueError
(
raise
ValueError
(
...
@@ -161,46 +134,90 @@ def main():
...
@@ -161,46 +134,90 @@ def main():
"
If you experiences issues generating the voice, considering setting the tts to melo.
"
"
If you experiences issues generating the voice, considering setting the tts to melo.
"
)
)
# 2. Prepare each part's arguments
def
overwrite_device_argument
(
common_device
:
Optional
[
str
],
*
handler_kwargs
):
def
overwrite_device_argument
(
common_device
:
Optional
[
str
],
*
handler_kwargs
):
if
common_device
:
if
common_device
:
for
kwargs
in
handler_kwargs
:
for
kwargs
in
handler_kwargs
:
if
hasattr
(
kwargs
,
"
lm_device
"
):
if
hasattr
(
kwargs
,
"
lm_device
"
):
kwargs
.
lm_device
=
common_device
kwargs
.
lm_device
=
common_device
if
hasattr
(
kwargs
,
"
tts_device
"
):
if
hasattr
(
kwargs
,
"
tts_device
"
):
kwargs
.
tts_device
=
common_device
kwargs
.
tts_device
=
common_device
if
hasattr
(
kwargs
,
"
stt_device
"
):
if
hasattr
(
kwargs
,
"
stt_device
"
):
kwargs
.
stt_device
=
common_device
kwargs
.
stt_device
=
common_device
if
hasattr
(
kwargs
,
"
paraformer_stt_device
"
):
if
hasattr
(
kwargs
,
"
paraformer_stt_device
"
):
kwargs
.
paraformer_stt_device
=
common_device
kwargs
.
paraformer_stt_device
=
common_device
# Call this function with the common device and all the handlers
overwrite_device_argument
(
def
prepare_module_args
(
module_kwargs
,
*
handler_kwargs
):
module_kwargs
.
device
,
optimal_mac_settings
(
module_kwargs
.
local_mac_optimal_settings
,
module_kwargs
)
if
platform
==
"
darwin
"
:
check_mac_settings
(
module_kwargs
)
overwrite_device_argument
(
module_kwargs
.
device
,
*
handler_kwargs
)
def
prepare_all_args
(
module_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
):
prepare_module_args
(
module_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
parler_tts_handler_kwargs
,
whisper_s
tt_handler_kwargs
,
melo_
tt
s
_handler_kwargs
,
paraformer_s
tt_handler_kwargs
,
chat_
tt
s
_handler_kwargs
,
)
)
prepare_args
(
whisper_stt_handler_kwargs
,
"
stt
"
)
rename_args
(
whisper_stt_handler_kwargs
,
"
stt
"
)
prepare_args
(
paraformer_stt_handler_kwargs
,
"
paraformer_stt
"
)
rename_args
(
paraformer_stt_handler_kwargs
,
"
paraformer_stt
"
)
prepare_args
(
language_model_handler_kwargs
,
"
lm
"
)
rename_args
(
language_model_handler_kwargs
,
"
lm
"
)
prepare_args
(
mlx_language_model_handler_kwargs
,
"
mlx_lm
"
)
rename_args
(
mlx_language_model_handler_kwargs
,
"
mlx_lm
"
)
prepare_args
(
parler_tts_handler_kwargs
,
"
tts
"
)
rename_args
(
parler_tts_handler_kwargs
,
"
tts
"
)
prepare_args
(
melo_tts_handler_kwargs
,
"
melo
"
)
rename_args
(
melo_tts_handler_kwargs
,
"
melo
"
)
prepare_args
(
chat_tts_handler_kwargs
,
"
chat_tts
"
)
rename_args
(
chat_tts_handler_kwargs
,
"
chat_tts
"
)
# 3. Build the pipeline
stop_event
=
Event
()
def
initialize_queues_and_events
():
# used to stop putting received audio chunks in queue until all setences have been processed by the TTS
return
{
should_listen
=
Event
()
"
stop_event
"
:
Event
(),
recv_audio_chunks_queue
=
Queue
()
"
should_listen
"
:
Event
(),
send_audio_chunks_queue
=
Queue
()
"
recv_audio_chunks_queue
"
:
Queue
(),
spoken_prompt_queue
=
Queue
()
"
send_audio_chunks_queue
"
:
Queue
(),
text_prompt_queue
=
Queue
()
"
spoken_prompt_queue
"
:
Queue
(),
lm_response_queue
=
Queue
()
"
text_prompt_queue
"
:
Queue
(),
"
lm_response_queue
"
:
Queue
(),
}
def
build_pipeline
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
queues_and_events
,
):
stop_event
=
queues_and_events
[
"
stop_event
"
]
should_listen
=
queues_and_events
[
"
should_listen
"
]
recv_audio_chunks_queue
=
queues_and_events
[
"
recv_audio_chunks_queue
"
]
send_audio_chunks_queue
=
queues_and_events
[
"
send_audio_chunks_queue
"
]
spoken_prompt_queue
=
queues_and_events
[
"
spoken_prompt_queue
"
]
text_prompt_queue
=
queues_and_events
[
"
text_prompt_queue
"
]
lm_response_queue
=
queues_and_events
[
"
lm_response_queue
"
]
if
module_kwargs
.
mode
==
"
local
"
:
if
module_kwargs
.
mode
==
"
local
"
:
from
connections.local_audio_streamer
import
LocalAudioStreamer
from
connections.local_audio_streamer
import
LocalAudioStreamer
...
@@ -238,10 +255,18 @@ def main():
...
@@ -238,10 +255,18 @@ def main():
setup_args
=
(
should_listen
,),
setup_args
=
(
should_listen
,),
setup_kwargs
=
vars
(
vad_handler_kwargs
),
setup_kwargs
=
vars
(
vad_handler_kwargs
),
)
)
stt
=
get_stt_handler
(
module_kwargs
,
stop_event
,
spoken_prompt_queue
,
text_prompt_queue
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
)
lm
=
get_llm_handler
(
module_kwargs
,
stop_event
,
text_prompt_queue
,
lm_response_queue
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
)
tts
=
get_tts_handler
(
module_kwargs
,
stop_event
,
lm_response_queue
,
send_audio_chunks_queue
,
should_listen
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
)
return
ThreadManager
([
*
comms_handlers
,
vad
,
stt
,
lm
,
tts
])
def
get_stt_handler
(
module_kwargs
,
stop_event
,
spoken_prompt_queue
,
text_prompt_queue
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
):
if
module_kwargs
.
stt
==
"
whisper
"
:
if
module_kwargs
.
stt
==
"
whisper
"
:
from
STT.whisper_stt_handler
import
WhisperSTTHandler
from
STT.whisper_stt_handler
import
WhisperSTTHandler
return
WhisperSTTHandler
(
stt
=
WhisperSTTHandler
(
stop_event
,
stop_event
,
queue_in
=
spoken_prompt_queue
,
queue_in
=
spoken_prompt_queue
,
queue_out
=
text_prompt_queue
,
queue_out
=
text_prompt_queue
,
...
@@ -249,8 +274,7 @@ def main():
...
@@ -249,8 +274,7 @@ def main():
)
)
elif
module_kwargs
.
stt
==
"
whisper-mlx
"
:
elif
module_kwargs
.
stt
==
"
whisper-mlx
"
:
from
STT.lightning_whisper_mlx_handler
import
LightningWhisperSTTHandler
from
STT.lightning_whisper_mlx_handler
import
LightningWhisperSTTHandler
return
LightningWhisperSTTHandler
(
stt
=
LightningWhisperSTTHandler
(
stop_event
,
stop_event
,
queue_in
=
spoken_prompt_queue
,
queue_in
=
spoken_prompt_queue
,
queue_out
=
text_prompt_queue
,
queue_out
=
text_prompt_queue
,
...
@@ -258,21 +282,20 @@ def main():
...
@@ -258,21 +282,20 @@ def main():
)
)
elif
module_kwargs
.
stt
==
"
paraformer
"
:
elif
module_kwargs
.
stt
==
"
paraformer
"
:
from
STT.paraformer_handler
import
ParaformerSTTHandler
from
STT.paraformer_handler
import
ParaformerSTTHandler
return
ParaformerSTTHandler
(
stt
=
ParaformerSTTHandler
(
stop_event
,
stop_event
,
queue_in
=
spoken_prompt_queue
,
queue_in
=
spoken_prompt_queue
,
queue_out
=
text_prompt_queue
,
queue_out
=
text_prompt_queue
,
setup_kwargs
=
vars
(
paraformer_stt_handler_kwargs
),
setup_kwargs
=
vars
(
paraformer_stt_handler_kwargs
),
)
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"
The STT should be either whisper, whisper-mlx, or paraformer.
"
)
"
The STT should be either whisper, whisper-mlx, or paraformer.
"
)
def
get_llm_handler
(
module_kwargs
,
stop_event
,
text_prompt_queue
,
lm_response_queue
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
):
if
module_kwargs
.
llm
==
"
transformers
"
:
if
module_kwargs
.
llm
==
"
transformers
"
:
from
LLM.language_model
import
LanguageModelHandler
from
LLM.language_model
import
LanguageModelHandler
return
LanguageModelHandler
(
lm
=
LanguageModelHandler
(
stop_event
,
stop_event
,
queue_in
=
text_prompt_queue
,
queue_in
=
text_prompt_queue
,
queue_out
=
lm_response_queue
,
queue_out
=
lm_response_queue
,
...
@@ -280,8 +303,7 @@ def main():
...
@@ -280,8 +303,7 @@ def main():
)
)
elif
module_kwargs
.
llm
==
"
mlx-lm
"
:
elif
module_kwargs
.
llm
==
"
mlx-lm
"
:
from
LLM.mlx_language_model
import
MLXLanguageModelHandler
from
LLM.mlx_language_model
import
MLXLanguageModelHandler
return
MLXLanguageModelHandler
(
lm
=
MLXLanguageModelHandler
(
stop_event
,
stop_event
,
queue_in
=
text_prompt_queue
,
queue_in
=
text_prompt_queue
,
queue_out
=
lm_response_queue
,
queue_out
=
lm_response_queue
,
...
@@ -289,10 +311,12 @@ def main():
...
@@ -289,10 +311,12 @@ def main():
)
)
else
:
else
:
raise
ValueError
(
"
The LLM should be either transformers or mlx-lm
"
)
raise
ValueError
(
"
The LLM should be either transformers or mlx-lm
"
)
def
get_tts_handler
(
module_kwargs
,
stop_event
,
lm_response_queue
,
send_audio_chunks_queue
,
should_listen
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
):
if
module_kwargs
.
tts
==
"
parler
"
:
if
module_kwargs
.
tts
==
"
parler
"
:
from
TTS.parler_handler
import
ParlerTTSHandler
from
TTS.parler_handler
import
ParlerTTSHandler
return
ParlerTTSHandler
(
tts
=
ParlerTTSHandler
(
stop_event
,
stop_event
,
queue_in
=
lm_response_queue
,
queue_in
=
lm_response_queue
,
queue_out
=
send_audio_chunks_queue
,
queue_out
=
send_audio_chunks_queue
,
...
@@ -307,7 +331,7 @@ def main():
...
@@ -307,7 +331,7 @@ def main():
"
Error importing MeloTTSHandler. You might need to run: python -m unidic download
"
"
Error importing MeloTTSHandler. You might need to run: python -m unidic download
"
)
)
raise
e
raise
e
tts
=
MeloTTSHandler
(
return
MeloTTSHandler
(
stop_event
,
stop_event
,
queue_in
=
lm_response_queue
,
queue_in
=
lm_response_queue
,
queue_out
=
send_audio_chunks_queue
,
queue_out
=
send_audio_chunks_queue
,
...
@@ -320,7 +344,7 @@ def main():
...
@@ -320,7 +344,7 @@ def main():
except
RuntimeError
as
e
:
except
RuntimeError
as
e
:
logger
.
error
(
"
Error importing ChatTTSHandler
"
)
logger
.
error
(
"
Error importing ChatTTSHandler
"
)
raise
e
raise
e
tts
=
ChatTTSHandler
(
return
ChatTTSHandler
(
stop_event
,
stop_event
,
queue_in
=
lm_response_queue
,
queue_in
=
lm_response_queue
,
queue_out
=
send_audio_chunks_queue
,
queue_out
=
send_audio_chunks_queue
,
...
@@ -330,14 +354,57 @@ def main():
...
@@ -330,14 +354,57 @@ def main():
else
:
else
:
raise
ValueError
(
"
The TTS should be either parler, melo or chatTTS
"
)
raise
ValueError
(
"
The TTS should be either parler, melo or chatTTS
"
)
# 4. Run the pipeline
def
main
():
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
=
parse_arguments
()
setup_logger
(
module_kwargs
.
log_level
)
prepare_all_args
(
module_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
)
queues_and_events
=
initialize_queues_and_events
()
pipeline_manager
=
build_pipeline
(
module_kwargs
,
socket_receiver_kwargs
,
socket_sender_kwargs
,
vad_handler_kwargs
,
whisper_stt_handler_kwargs
,
paraformer_stt_handler_kwargs
,
language_model_handler_kwargs
,
mlx_language_model_handler_kwargs
,
parler_tts_handler_kwargs
,
melo_tts_handler_kwargs
,
chat_tts_handler_kwargs
,
queues_and_events
,
)
try
:
try
:
pipeline_manager
=
ThreadManager
([
*
comms_handlers
,
vad
,
stt
,
lm
,
tts
])
pipeline_manager
.
start
()
pipeline_manager
.
start
()
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
pipeline_manager
.
stop
()
pipeline_manager
.
stop
()
if
__name__
==
"
__main__
"
:
if
__name__
==
"
__main__
"
:
main
()
main
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment