Commit 8a43a2e6 authored by Eustache Le Bihan

clean code and add comments

parent 3196799e
import logging
import os
import socket
import sys
import threading
from collections import deque
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue
from threading import Event, Thread
from time import perf_counter

import numpy as np
import torch
from nltk.tokenize import sent_tokenize
from rich.console import Console
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
    TextIteratorStreamer
)

from parler_tts import (
    ParlerTTSForConditionalGeneration,
    ParlerTTSStreamer
)

from utils import (
    VADIterator,
    int2float,
    next_power_of_2
)
@@ -44,8 +45,10 @@ torch._inductor.config.fx_graph_cache = True
# mind this parameter! it should be >= 2 * the number of padded prompt sizes used by the TTS
torch._dynamo.config.cache_size_limit = 15
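# e.g. with the TTS default max_prompt_pad_length=8 below, warmup compiles prompt lengths
# 2**2 .. 2**7 (6 sizes), so 2 * 6 = 12 graphs stay under the limit of 15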

console = Console()
@dataclass
class ModuleArguments:
    log_level: str = field(
@@ -55,7 +58,12 @@ class ModuleArguments:
        }
    )
class ThreadManager:
    """
    Manages multiple threads used to execute given handler tasks.
    """
    def __init__(self, handlers):
        self.handlers = handlers
        self.threads = []
@@ -72,7 +80,16 @@ class ThreadManager:
        for thread in self.threads:
            thread.join()
class BaseHandler:
    """
    Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
    The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
    To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
    Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
    The `cleanup` method handles stopping the handler and places b"END" in the output queue.
    """
    def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
        self.stop_event = stop_event
        self.queue_in = queue_in
@@ -135,6 +152,10 @@ class SocketReceiverArguments:
class SocketReceiver:
    """
    Handles reception of the audio packets from the client.
    """
    def __init__(
        self,
        stop_event,
@@ -201,6 +222,10 @@ class SocketSenderArguments:
class SocketSender:
    """
    Handles sending generated audio packets to the clients.
    """
    def __init__(
        self,
        stop_event,
@@ -273,6 +298,11 @@ class VADHandlerArguments:
class VADHandler(BaseHandler):
    """
    Handles voice activity detection. When voice activity is detected, audio will be accumulated
    until the end of speech is detected and then passed to the following part.
    """
    def setup(
        self,
        should_listen,
@@ -284,11 +314,11 @@ class VADHandler(BaseHandler):
        speech_pad_ms=30,
    ):
        self.should_listen = should_listen
        self.sample_rate = sample_rate
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.max_speech_ms = max_speech_ms
        self.model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad')
        self.iterator = VADIterator(
            self.model,
@@ -305,8 +335,8 @@ class VADHandler(BaseHandler):
        if vad_output is not None and len(vad_output) != 0:
            logger.debug("VAD: end of speech detected")
            array = torch.cat(vad_output).cpu().numpy()
            duration_ms = len(array) / self.sample_rate * 1000
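            # e.g. a 24_000-sample array at sample_rate=16000 lasts 24_000 / 16_000 * 1000 = 1500 ms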
            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                logger.debug(f"audio input of duration: {len(array) / self.sample_rate}s, skipping")
            else:
                self.should_listen.clear()
@@ -373,6 +403,10 @@ class WhisperSTTHandlerArguments:
class WhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """
    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
@@ -381,16 +415,17 @@ class WhisperSTTHandler(BaseHandler):
        compile_mode=None,
        gen_kwargs={}
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        self.gen_kwargs = gen_kwargs

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)
        # compile
        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
@@ -402,20 +437,19 @@ class WhisperSTTHandler(BaseHandler):
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        return input_features
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
        n_steps = 1 if self.compile_mode == "default" else 2
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device
        )
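        # (1, num_mel_bins, 3000): 3000 mel frames correspond to Whisper's fixed 30 s input window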
        if self.compile_mode not in (None, "default"):
            # generating more tokens than previously will trigger CUDA graphs capture
            # one should warm up with a number of generated tokens above the max targeted for subsequent generations
@@ -427,28 +461,35 @@ class WhisperSTTHandler(BaseHandler):
        else:
            warmup_gen_kwargs = self.gen_kwargs

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        torch.cuda.synchronize()
        start_event.record()
        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
        end_event.record()
        torch.cuda.synchronize()

        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")

        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        pred_text = self.processor.batch_decode(
            pred_ids,
            skip_special_tokens=True,
            decode_with_timestamps=False
        )[0]

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
@@ -509,6 +550,10 @@ class LanguageModelHandlerArguments:
class Chat:
    """
    Handles the chat using a circular buffer to avoid OOM issues.
    """

    def __init__(self, size):
        self.init_chat_message = None
        self.buffer = deque(maxlen=size)
@@ -527,25 +572,30 @@ class Chat:
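# Circular-buffer sketch (Chat's append is elided above; this assumes it stores into
# self.buffer): with Chat(2), appending four messages keeps only the last two, because
# deque(maxlen=2) silently evicts the oldest entries:
#
#     chat = Chat(2)
#     for i in range(4):
#         chat.append({"role": "user", "content": f"msg {i}"})
#     # chat.buffer now holds msg 2 and msg 3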
class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part.
    """
    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        chat_size=3,
        gen_kwargs={},
        user_role="user",
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
@@ -556,6 +606,12 @@ class LanguageModelHandler(BaseHandler):
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs
        }

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
@@ -563,26 +619,12 @@ class LanguageModelHandler(BaseHandler):
            self.chat.init_chat(
                {"role": init_chat_role, "content": init_chat_prompt}
            )
        self.user_role = user_role

        self.warmup()
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
@@ -592,25 +634,33 @@ class LanguageModelHandler(BaseHandler):
            **self.gen_kwargs
        }

        n_steps = 2

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        torch.cuda.synchronize()
        start_event.record()
        for _ in range(n_steps):
            thread = Thread(target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs)
            thread.start()
            for _ in self.streamer:
                pass
        end_event.record()
        torch.cuda.synchronize()

        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
    def process(self, prompt):
        logger.debug("inferring language model...")

        self.chat.append(
            {"role": self.user_role, "content": prompt}
        )
        thread = Thread(target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs)
        thread.start()

        generated_text, printable_text = "", ""
        for new_text in self.streamer:
            generated_text += new_text
            printable_text += new_text
@@ -618,9 +668,11 @@ class LanguageModelHandler(BaseHandler):
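            # e.g. once printable_text reaches "Hello there. How", sent_tokenize returns
            # ["Hello there.", "How"]: the complete first sentence is flushed to the TTS
            # while the trailing fragment keeps accumulating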
            if len(sentences) > 1:
                yield sentences[0]
                printable_text = new_text

        self.chat.append(
            {"role": "assistant", "content": generated_text}
        )

        # don't forget the last sentence
        yield printable_text
@@ -689,37 +741,33 @@ class ParlerTTSHandler(BaseHandler):
model_name="ylacombe/parler-tts-mini-jenny-30H", model_name="ylacombe/parler-tts-mini-jenny-30H",
device="cuda", device="cuda",
torch_dtype="float16", torch_dtype="float16",
max_prompt_pad_length=8,
gen_kwargs={},
compile_mode=None, compile_mode=None,
gen_kwargs={},
max_prompt_pad_length=8,
description=( description=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast." "She speaks very fast."
), ),
play_steps_s=1 play_steps_s=1
): ):
        self.should_listen = should_listen
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.gen_kwargs = gen_kwargs
        self.compile_mode = compile_mode
        self.max_prompt_pad_length = max_prompt_pad_length
        self.description = description

        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype
        ).to(device)

        framerate = self.model.audio_encoder.config.frame_rate
        self.play_steps = int(framerate * play_steps_s)
        if self.compile_mode not in (None, "default"):
            logger.warning("Torch compilation modes that capture CUDA graphs are not yet compatible with the STT part. Reverting to 'default'")
            self.compile_mode = "default"
@@ -727,6 +775,7 @@ class ParlerTTSHandler(BaseHandler):
        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(self.model.forward, mode=self.compile_mode, fullgraph=True)

        self.warmup()
    def prepare_model_inputs(
@@ -752,29 +801,40 @@ class ParlerTTSHandler(BaseHandler):
"prompt_attention_mask": prompt_attention_mask, "prompt_attention_mask": prompt_attention_mask,
**self.gen_kwargs **self.gen_kwargs
} }
return gen_kwargs return gen_kwargs
    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
        n_steps = 1 if self.compile_mode == "default" else 2

        torch.cuda.synchronize()
        start_event.record()
        if self.compile_mode:
            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
            for pad_length in pad_lengths[::-1]:
                model_kwargs = self.prepare_model_inputs(
                    "dummy prompt",
                    max_length_prompt=pad_length,
                    pad=True
                )
                for _ in range(n_steps):
                    _ = self.model.generate(**model_kwargs)
                logger.info(f"Warmed up length {pad_length} tokens!")
        else:
            model_kwargs = self.prepare_model_inputs("dummy prompt")
            for _ in range(n_steps):
                _ = self.model.generate(**model_kwargs)

        end_event.record()
        torch.cuda.synchronize()
        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
    def process(self, llm_sentence):
        console.print(f"[green]ASSISTANT: {llm_sentence}")
@@ -808,10 +868,14 @@ class ParlerTTSHandler(BaseHandler):
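            # scale float audio in [-1.0, 1.0] to 16-bit PCM before sending it over the socket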
            audio_chunk = np.int16(audio_chunk * 32767)
            yield audio_chunk

        self.should_listen.set()
def prepare_args(args, prefix):
    """
    Renames arguments by removing the prefix and prepares the gen_kwargs.
    """
    gen_kwargs = {}
    for key in copy(args.__dict__):
        if key.startswith(prefix):
@@ -860,6 +924,7 @@ def main():
        parler_tts_handler_kwargs,
    ) = parser.parse_args_into_dataclasses()
    # 1. Handle logger
    global logger
    logging.basicConfig(
        level=module_kwargs.log_level.upper(),
@@ -871,12 +936,15 @@ def main():
    if module_kwargs.log_level == "debug":
        torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
    # 2. Prepare each part's arguments
    prepare_args(whisper_stt_handler_kwargs, "stt")
    prepare_args(language_model_handler_kwargs, "lm")
    prepare_args(parler_tts_handler_kwargs, "tts")
    # 3. Build the pipeline
    stop_event = Event()
    # used to stop putting received audio chunks in the queue until all sentences have been processed by the TTS
    should_listen = Event()
    recv_audio_chunks_queue = Queue()
    send_audio_chunks_queue = Queue()
    spoken_prompt_queue = Queue()
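    # rough data flow (the handler wiring below is partly elided in this diff):
    #   recv_audio_chunks_queue -> VAD -> spoken_prompt_queue -> STT -> LM -> TTS -> send_audio_chunks_queue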
@@ -926,6 +994,7 @@ def main():
        port=socket_sender_kwargs.send_port,
    )
    # 4. Run the pipeline
    try:
        pipeline_manager = ThreadManager([vad, tts, lm, stt, recv_handler, send_handler])
        pipeline_manager.start()