Unverified commit e127cc7a, authored by wuhongsheng, committed by GitHub

Merge pull request #1 from eustlb/open-api-fix

Open api fix
parents d6b0941f 506e61e2
FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
ENV PYTHONUNBUFFERED 1
WORKDIR /usr/src/app
# Install packages
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
\ No newline at end of file
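Usage note (not part of the diff): a sketch of building and running this L4T image on a Jetson device. The image tag is illustrative, and the file name under which this Dockerfile was committed is not shown here, so adjust `-f` accordingly:
```bash
# Build the Jetson/L4T image (assumes the file is the repo's ./Dockerfile)
docker build -f Dockerfile -t speech-to-speech:l4t .
# Jetson needs the NVIDIA container runtime to expose the GPU
docker run --runtime nvidia -it --rm speech-to-speech:l4t python3 s2s_pipeline.py
```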
@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}
class MLXLanguageModelHandler(BaseHandler):
"""
@@ -44,7 +52,7 @@ class MLXLanguageModelHandler(BaseHandler):
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
dummy_input_text = "Write me a poem about Machine Learning."
dummy_input_text = "Repeat the word 'home'."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
n_steps = 2
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
def process(self, prompt):
logger.debug("infering language model...")
language_code = None
if isinstance(prompt, tuple):
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
@@ -86,9 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
output += t
curr_output += t
if curr_output.endswith((".", "?", "!", "<|end|>")):
- yield curr_output.replace("<|end|>", "")
+ yield (curr_output.replace("<|end|>", ""), language_code)
curr_output = ""
generated_text = output.replace("<|end|>", "")
torch.mps.empty_cache()
self.chat.append({"role": "assistant", "content": generated_text})
self.chat.append({"role": "assistant", "content": generated_text})
\ No newline at end of file
@@ -54,37 +54,36 @@ class OpenApiModelHandler(BaseHandler):
logger.info(
f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s"
)
-    def process(self, prompt):
-        logger.debug("call api language model...")
-        self.chat.append({"role": self.user_role, "content": prompt})
+    def process(self, prompt):
+        logger.debug("call api language model...")
+        language_code = None
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+        self.chat.append({"role": self.user_role, "content": prompt})
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": self.user_role, "content": prompt},
            ],
            stream=self.stream
        )
        if self.stream:
            generated_text, printable_text = "", ""
            for chunk in response:
                new_text = chunk.choices[0].delta.content or ""
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield sentences[0], language_code
                    printable_text = new_text
            self.chat.append({"role": "assistant", "content": generated_text})
            # don't forget last sentence
            yield printable_text, language_code
        else:
            generated_text = response.choices[0].message.content
            self.chat.append({"role": "assistant", "content": generated_text})
            yield generated_text, language_code
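Reading of this hunk: the fix defines `language_code` (defaulting to `None`) and unpacks `(prompt, language_code)` tuples before the prompt is appended to the chat and sent to the API, so both the streaming and non-streaming branches can safely yield `(text, language_code)` pairs downstream.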
@@ -14,7 +14,7 @@
* [Usage](#usage)
- [Docker Server approach](#docker-server)
- [Server/Client approach](#serverclient-approach)
- - [Local approach](#local-approach)
+ - [Local approach](#local-approach-running-on-mac)
* [Command-line usage](#command-line-usage)
- [Model parameters](#model-parameters)
- [Generation parameters](#generation-parameters)
@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
### Server/Client Approach
- To run the pipeline on the server:
- ```bash
- python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
- ```
+ 1. Run the pipeline on the server:
+ ```bash
+ python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
+ ```
- Then run the client locally to handle sending microphone input and receiving generated audio:
- ```bash
- python listen_and_play.py --host <IP address of your server>
- ```
+ 2. Run the client locally to handle microphone input and receive generated audio:
+ ```bash
+ python listen_and_play.py --host <IP address of your server>
+ ```
- ### Running on Mac
- To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
- ```bash
- python s2s_pipeline.py --local_mac_optimal_settings
- ```
+ ### Local Approach (Mac)
+ 1. For optimal settings on Mac:
+ ```bash
+ python s2s_pipeline.py --local_mac_optimal_settings
+ ```
- You can also pass `--device mps` to have all the models set to device mps.
- The local mac optimal settings set the mode to be local as explained above and change the models to:
- - LightningWhisperMLX
- - MLX LM
- - MeloTTS
+ This setting:
+ - Adds `--device mps` to use MPS for all models.
+ - Sets LightningWhisperMLX for STT
+ - Sets MLX LM for language model
+ - Sets MeloTTS for TTS
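For reference, the `optimal_mac_settings` helper in the s2s_pipeline.py diff below makes this roughly equivalent to spelling the flags out. A sketch, assuming the dataclass fields map one-to-one onto CLI flags:
```bash
# Rough explicit equivalent of --local_mac_optimal_settings (assumed flag names)
python s2s_pipeline.py \
--mode local \
--device mps \
--stt whisper-mlx \
--llm mlx-lm \
--tts melo
```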
### Recommended usage with CUDA
@@ -117,6 +118,57 @@ python s2s_pipeline.py \
For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
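To make that concrete, here is a hedged sketch of a CUDA invocation that respects this constraint; the flag values are assumptions, not part of this diff:
```bash
# Sketch: compile the STT aggressively, but keep the streaming Parler-TTS on
# 'default', since 'reduce-overhead' and 'max-autotune' capture CUDA Graphs.
python s2s_pipeline.py \
--recv_host 0.0.0.0 \
--send_host 0.0.0.0 \
--stt_compile_mode reduce-overhead \
--tts_compile_mode default
```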
### Multi-language Support
The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
#### With the server version:
For automatic language detection:
```bash
python s2s_pipeline.py \
--stt_model_name large-v3 \
--language auto \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
```
Or for one language in particular, Chinese in this example:
```bash
python s2s_pipeline.py \
--stt_model_name large-v3 \
--language zh \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
```
#### Local Mac Setup
For automatic language detection:
```bash
python s2s_pipeline.py \
--local_mac_optimal_settings \
--device mps \
--stt_model_name large-v3 \
--language auto \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
```
Or for one language in particular, Chinese in this example:
```bash
python s2s_pipeline.py \
--local_mac_optimal_settings \
--device mps \
--stt_model_name large-v3 \
--language zh \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
```
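Under the hood, the handler diffs below implement this by passing `(text, language_code)` tuples between pipeline stages. A minimal, self-contained sketch of the convention (illustrative only; the real handlers are generators wired to queues):
```python
# Each stage unpacks defensively and forwards the language code downstream.
def handle(prompt):
    language_code = None
    if isinstance(prompt, tuple):  # the STT attached a detected language
        prompt, language_code = prompt
    reply = f"(reply to: {prompt})"
    return reply, language_code  # the TTS picks its voice from this code

print(handle(("你好", "zh")))  # -> ('(reply to: 你好)', 'zh')
print(handle("hello"))  # -> ('(reply to: hello)', None)
```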
## Command-line Usage
### Model Parameters
...
@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np
from rich.console import Console
from copy import copy
import torch
logger = logging.getLogger(__name__)
console = Console()
SUPPORTED_LANGUAGES = [
"en",
"fr",
"es",
"zh",
"ja",
"ko",
]
class LightningWhisperSTTHandler(BaseHandler):
"""
@@ -19,15 +29,19 @@ class LightningWhisperSTTHandler(BaseHandler):
def setup(
self,
model_name="distil-large-v3",
device="cuda",
device="mps",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
if len(model_name.split("/")) > 1:
model_name = model_name.split("/")[-1]
self.device = device
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
self.start_language = language
self.last_language = language
self.warmup()
def warmup(self):
@@ -46,10 +60,26 @@ class LightningWhisperSTTHandler(BaseHandler):
global pipeline_start
pipeline_start = perf_counter()
-        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+        if self.start_language != 'auto':
+            transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
+        else:
+            transcription_dict = self.model.transcribe(spoken_prompt)
+            language_code = transcription_dict["language"]
+            if language_code not in SUPPORTED_LANGUAGES:
+                logger.warning(f"Whisper detected unsupported language: {language_code}")
+                if self.last_language in SUPPORTED_LANGUAGES:  # reprocess with the last language
+                    transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
+                else:
+                    transcription_dict = {"text": "", "language": "en"}
+            else:
+                self.last_language = language_code
+
+        pred_text = transcription_dict["text"].strip()
+        language_code = transcription_dict["language"]
torch.mps.empty_cache()
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")
- yield pred_text
+ yield (pred_text, language_code)
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN_NEWEST",
"en": "EN",
"fr": "FR",
"es": "ES",
"zh": "ZH",
@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-Newest",
"en": "EN-BR",
"fr": "FR",
"es": "ES",
"zh": "ZH",
...
@@ -68,7 +68,7 @@ class ParlerTTSHandler(BaseHandler):
if self.compile_mode not in (None, "default"):
logger.warning(
"Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
"Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'"
)
self.compile_mode = "default"
@@ -147,6 +147,9 @@ class ParlerTTSHandler(BaseHandler):
)
def process(self, llm_sentence):
if isinstance(llm_sentence, tuple):
llm_sentence, _ = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
...
@@ -86,3 +86,7 @@ class VADHandler(BaseHandler):
)
array = enhanced.numpy().squeeze()
yield array
@property
def min_time_to_debug(self):
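# Lower than the 0.001 s BaseHandler default, so even very fast VAD iterations show up in the debug timing logs.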
return 0.00001
@@ -42,6 +42,6 @@ class VADHandlerArguments:
audio_enhancement: bool = field(
default=False,
metadata={
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is True."
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
},
)
@@ -36,7 +36,8 @@ class BaseHandler:
start_time = perf_counter()
for output in self.process(input):
self._times.append(perf_counter() - start_time)
logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
if self.last_time > self.min_time_to_debug:
logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
self.queue_out.put(output)
start_time = perf_counter()
@@ -46,6 +47,10 @@ class BaseHandler:
@property
def last_time(self):
return self._times[-1]
@property
def min_time_to_debug(self):
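# Iterations faster than this threshold (in seconds) are not logged; subclasses such as VADHandler override it.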
return 0.001
def cleanup(self):
pass
@@ -4,6 +4,7 @@ services:
pipeline:
build:
context: .
dockerfile: ${DOCKERFILE:-Dockerfile}
command:
- python3
- s2s_pipeline.py
...
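Usage note (not part of the diff): Compose substitutes `${DOCKERFILE:-Dockerfile}` from the environment, falling back to `Dockerfile`, so an alternative build file — e.g. the L4T one added in this commit — can be selected per machine. The file name below is hypothetical:
```bash
# Default build uses ./Dockerfile
docker compose up --build
# Hypothetical: point the build at the L4T Dockerfile from this commit
DOCKERFILE=Dockerfile.arm64 docker compose up --build
```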
@@ -50,11 +50,10 @@ console = Console()
logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs
- def prepare_args(args, prefix):
+ def rename_args(args, prefix):
"""
Rename arguments by removing the prefix and prepare the gen_kwargs.
"""
gen_kwargs = {}
for key in copy(args.__dict__):
if key.startswith(prefix):
@@ -68,7 +67,7 @@ def prepare_args(args, prefix):
args.__dict__["gen_kwargs"] = gen_kwargs
- def main():
+ def parse_arguments():
parser = HfArgumentParser(
(
ModuleArguments,
@@ -86,71 +85,43 @@ def main():
)
)
-    # 0. Parse CLI arguments
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # Parse configurations from a JSON file if specified
-        (
-            module_kwargs,
-            socket_receiver_kwargs,
-            socket_sender_kwargs,
-            vad_handler_kwargs,
-            whisper_stt_handler_kwargs,
-            paraformer_stt_handler_kwargs,
-            language_model_handler_kwargs,
-            open_api_language_model_handler_kwargs,
-            mlx_language_model_handler_kwargs,
-            parler_tts_handler_kwargs,
-            melo_tts_handler_kwargs,
-            chat_tts_handler_kwargs,
-        ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        # Parse arguments from command line if no JSON file is provided
-        (
-            module_kwargs,
-            socket_receiver_kwargs,
-            socket_sender_kwargs,
-            vad_handler_kwargs,
-            whisper_stt_handler_kwargs,
-            paraformer_stt_handler_kwargs,
-            language_model_handler_kwargs,
-            open_api_language_model_handler_kwargs,
-            mlx_language_model_handler_kwargs,
-            parler_tts_handler_kwargs,
-            melo_tts_handler_kwargs,
-            chat_tts_handler_kwargs,
-        ) = parser.parse_args_into_dataclasses()
-
-    # 1. Handle logger
+        return parser.parse_args_into_dataclasses()
def setup_logger(log_level):
global logger
logging.basicConfig(
- level=module_kwargs.log_level.upper(),
+ level=log_level.upper(),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# torch compile logs
if module_kwargs.log_level == "debug":
if log_level == "debug":
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
-    def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
-        if mac_optimal_settings:
-            for kwargs in handler_kwargs:
-                if hasattr(kwargs, "device"):
-                    kwargs.device = "mps"
-                if hasattr(kwargs, "mode"):
-                    kwargs.mode = "local"
-                if hasattr(kwargs, "stt"):
-                    kwargs.stt = "whisper-mlx"
-                if hasattr(kwargs, "llm"):
-                    kwargs.llm = "mlx-lm"
-                if hasattr(kwargs, "tts"):
-                    kwargs.tts = "melo"
-
-    optimal_mac_settings(
-        module_kwargs.local_mac_optimal_settings,
-        module_kwargs,
-    )
+def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
+    if mac_optimal_settings:
+        for kwargs in handler_kwargs:
+            if hasattr(kwargs, "device"):
+                kwargs.device = "mps"
+            if hasattr(kwargs, "mode"):
+                kwargs.mode = "local"
+            if hasattr(kwargs, "stt"):
+                kwargs.stt = "whisper-mlx"
+            if hasattr(kwargs, "llm"):
+                kwargs.llm = "mlx-lm"
+            if hasattr(kwargs, "tts"):
+                kwargs.tts = "melo"
def check_mac_settings(module_kwargs):
if platform == "darwin":
if module_kwargs.device == "cuda":
raise ValueError(
......@@ -165,48 +136,95 @@ def main():
"If you experiences issues generating the voice, considering setting the tts to melo."
)
-    # 2. Prepare each part's arguments
-    def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
-        if common_device:
-            for kwargs in handler_kwargs:
-                if hasattr(kwargs, "lm_device"):
-                    kwargs.lm_device = common_device
-                if hasattr(kwargs, "tts_device"):
-                    kwargs.tts_device = common_device
-                if hasattr(kwargs, "stt_device"):
-                    kwargs.stt_device = common_device
-                if hasattr(kwargs, "paraformer_stt_device"):
-                    kwargs.paraformer_stt_device = common_device
-
-    # Call this function with the common device and all the handlers
-    overwrite_device_argument(
-        module_kwargs.device,
+def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
+    if common_device:
+        for kwargs in handler_kwargs:
+            if hasattr(kwargs, "lm_device"):
+                kwargs.lm_device = common_device
+            if hasattr(kwargs, "tts_device"):
+                kwargs.tts_device = common_device
+            if hasattr(kwargs, "stt_device"):
+                kwargs.stt_device = common_device
+            if hasattr(kwargs, "paraformer_stt_device"):
+                kwargs.paraformer_stt_device = common_device
def prepare_module_args(module_kwargs, *handler_kwargs):
optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs)
if platform == "darwin":
check_mac_settings(module_kwargs)
overwrite_device_argument(module_kwargs.device, *handler_kwargs)
def prepare_all_args(
module_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
):
prepare_module_args(
module_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
- whisper_stt_handler_kwargs,
- paraformer_stt_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
)
-    prepare_args(whisper_stt_handler_kwargs, "stt")
-    prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
-    prepare_args(language_model_handler_kwargs, "lm")
-    prepare_args(open_api_language_model_handler_kwargs, "open_api")
-    prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
-    prepare_args(parler_tts_handler_kwargs, "tts")
-    prepare_args(melo_tts_handler_kwargs, "melo")
-    prepare_args(chat_tts_handler_kwargs, "chat_tts")
-
-    # 3. Build the pipeline
-    stop_event = Event()
-    # used to stop putting received audio chunks in queue until all sentences have been processed by the TTS
-    should_listen = Event()
-    recv_audio_chunks_queue = Queue()
-    send_audio_chunks_queue = Queue()
-    spoken_prompt_queue = Queue()
-    text_prompt_queue = Queue()
-    lm_response_queue = Queue()
+    rename_args(whisper_stt_handler_kwargs, "stt")
+    rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
+    rename_args(language_model_handler_kwargs, "lm")
+    rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
+    rename_args(open_api_language_model_handler_kwargs, "open_api")
+    rename_args(parler_tts_handler_kwargs, "tts")
+    rename_args(melo_tts_handler_kwargs, "melo")
+    rename_args(chat_tts_handler_kwargs, "chat_tts")
def initialize_queues_and_events():
return {
"stop_event": Event(),
"should_listen": Event(),
"recv_audio_chunks_queue": Queue(),
"send_audio_chunks_queue": Queue(),
"spoken_prompt_queue": Queue(),
"text_prompt_queue": Queue(),
"lm_response_queue": Queue(),
}
def build_pipeline(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
queues_and_events,
):
stop_event = queues_and_events["stop_event"]
should_listen = queues_and_events["should_listen"]
recv_audio_chunks_queue = queues_and_events["recv_audio_chunks_queue"]
send_audio_chunks_queue = queues_and_events["send_audio_chunks_queue"]
spoken_prompt_queue = queues_and_events["spoken_prompt_queue"]
text_prompt_queue = queues_and_events["text_prompt_queue"]
lm_response_queue = queues_and_events["lm_response_queue"]
if module_kwargs.mode == "local":
from connections.local_audio_streamer import LocalAudioStreamer
@@ -243,10 +261,18 @@ def main():
setup_args=(should_listen,),
setup_kwargs=vars(vad_handler_kwargs),
)
stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs)
return ThreadManager([*comms_handlers, vad, stt, lm, tts])
def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs):
if module_kwargs.stt == "whisper":
from STT.whisper_stt_handler import WhisperSTTHandler
- stt = WhisperSTTHandler(
+ return WhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
@@ -254,8 +280,7 @@ def main():
)
elif module_kwargs.stt == "whisper-mlx":
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
- stt = LightningWhisperSTTHandler(
+ return LightningWhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
@@ -263,31 +288,36 @@ def main():
)
elif module_kwargs.stt == "paraformer":
from STT.paraformer_handler import ParaformerSTTHandler
- stt = ParaformerSTTHandler(
+ return ParaformerSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
setup_kwargs=vars(paraformer_stt_handler_kwargs),
)
else:
- raise ValueError(
-     "The STT should be either whisper, whisper-mlx, or paraformer."
- )
+ raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
def get_llm_handler(
module_kwargs,
stop_event,
text_prompt_queue,
lm_response_queue,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs
):
if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler
- lm = LanguageModelHandler(
+ return LanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs),
)
elif module_kwargs.llm == "open_api":
from LLM.openai_api_language_model import OpenApiModelHandler
- lm = OpenApiModelHandler(
+ return OpenApiModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
@@ -296,7 +326,7 @@ def main():
elif module_kwargs.llm == "mlx-lm":
from LLM.mlx_language_model import MLXLanguageModelHandler
- lm = MLXLanguageModelHandler(
+ return MLXLanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
@@ -305,17 +335,18 @@ def main():
else:
raise ValueError("The LLM should be either transformers or mlx-lm")
def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs):
if module_kwargs.tts == "parler":
from TTS.parler_handler import ParlerTTSHandler
- tts = ParlerTTSHandler(
+ return ParlerTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
setup_args=(should_listen,),
setup_kwargs=vars(parler_tts_handler_kwargs),
)
elif module_kwargs.tts == "melo":
try:
from TTS.melo_handler import MeloTTSHandler
@@ -324,7 +355,7 @@ def main():
"Error importing MeloTTSHandler. You might need to run: python -m unidic download"
)
raise e
- tts = MeloTTSHandler(
+ return MeloTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
@@ -337,7 +368,7 @@ def main():
except RuntimeError as e:
logger.error("Error importing ChatTTSHandler")
raise e
- tts = ChatTTSHandler(
+ return ChatTTSHandler(
stop_event,
queue_in=lm_response_queue,
queue_out=send_audio_chunks_queue,
@@ -347,14 +378,60 @@ def main():
else:
raise ValueError("The TTS should be either parler, melo or chatTTS")
- # 4. Run the pipeline
def main():
(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
) = parse_arguments()
setup_logger(module_kwargs.log_level)
prepare_all_args(
module_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
)
queues_and_events = initialize_queues_and_events()
pipeline_manager = build_pipeline(
module_kwargs,
socket_receiver_kwargs,
socket_sender_kwargs,
vad_handler_kwargs,
whisper_stt_handler_kwargs,
paraformer_stt_handler_kwargs,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
queues_and_events,
)
try:
- pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
pipeline_manager.start()
except KeyboardInterrupt:
pipeline_manager.stop()
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file