diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index 8d2cef00aae933c34bf90dcccefe0aa9e929ae01..8132e1abfc94dae7b87641ae0c1d28e7e74689ad 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -9,10 +9,8 @@ import os
 from pathlib import Path
 from dataclasses import dataclass, field
 from copy import copy
-import multiprocessing
 
 import numpy as np
-import soundfile as sf
 import torch
 from nltk.tokenize import sent_tokenize
 from rich.console import Console
@@ -72,10 +70,6 @@ class ThreadManager:
         for thread in self.threads:
             thread.join()
 
-
-pipeline_start = None
-
-
 class BaseHandler:
     def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
         self.stop_event = stop_event
@@ -170,8 +164,9 @@ class SocketReceiver:
         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         self.socket.bind((self.host, self.port))
         self.socket.listen(1)
+        logger.info("Receiver waiting for a connection...")
         self.conn, _ = self.socket.accept()
-        logger.debug("receiver connected")
+        logger.info("Receiver connected")
         self.should_listen.set()
 
         while not self.stop_event.is_set():
@@ -183,7 +178,7 @@ class SocketReceiver:
             if self.should_listen.is_set():
                 self.queue_out.put(audio_chunk)
         self.conn.close()
-        logger.debug("Receiver closed")
+        logger.info("Receiver closed")
 
 
 @dataclass
@@ -222,8 +217,9 @@ class SocketSender:
         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         self.socket.bind((self.host, self.port))
         self.socket.listen(1)
+        logger.info("Sender waiting for a connection...")
         self.conn, _ = self.socket.accept()
-        logger.debug("sender connected")
+        logger.info("Sender connected")
 
         while not self.stop_event.is_set():
             audio_chunk = self.queue_in.get()
@@ -231,7 +227,7 @@ class SocketSender:
             if isinstance(audio_chunk, bytes) and audio_chunk == b'END':
                 break
         self.conn.close()
-        logger.debug("Sender closed")
+        logger.info("Sender closed")
 
 
 @dataclass
@@ -426,6 +422,8 @@ class WhisperSTTHandler(BaseHandler):
                 "max_new_tokens": self.gen_kwargs["max_new_tokens"],
                 **self.gen_kwargs
             }
+        else:
+            warmup_gen_kwargs = self.gen_kwargs
 
         start_event.record()
         for _ in range(n_steps):
@@ -515,12 +513,14 @@ class LanguageModelHandler(BaseHandler):
         init_chat_role=None,
         init_chat_prompt="You are a helpful AI assistant.",
     ):
+        self.torch_dtype = getattr(torch, torch_dtype)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch_dtype,
             trust_remote_code=True
         ).to(device)
+        self.device = device
         self.pipe = pipeline(
             "text-generation",
             model=self.model,
@@ -544,6 +544,37 @@ class LanguageModelHandler(BaseHandler):
             **gen_kwargs
         }
         self.user_role = user_role
+        self.warmup()
+
+    def warmup(self):
+        # 2 warmup steps, for the uncompiled path as well as compile modes with CUDA graph capture
+        n_steps = 2
+        logger.info(f"Warming up {self.__class__.__name__}")
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+
+        dummy_input_text = "Write me a poem about Machine Learning."
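+        # Warm up on a one-turn dummy chat; setting min_new_tokens to
+        # max_new_tokens forces each pass to generate the full sequence
+        # length that real requests may reach (mirrors the STT warmup above).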
+        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
+        warmup_gen_kwargs = {
+            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
+            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
+            **self.gen_kwargs
+        }
+
+        start_event.record()
+        for _ in range(n_steps):
+            thread = Thread(target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs)
+            thread.start()
+            for _ in self.streamer:
+                pass
+
+        end_event.record()
+        torch.cuda.synchronize()
+        logger.info(f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
 
     def process(self, prompt):
         self.chat.append(
@@ -668,7 +696,7 @@ class ParlerTTSHandler(BaseHandler):
 
         for i, audio_chunk in enumerate(streamer):
             if i == 0:
-                logger.debug(f"time to first audio: {perf_counter() - pipeline_start:.3f}")
+                logger.info(f"Time to first audio: {perf_counter() - pipeline_start:.3f}")
             audio_chunk = np.int16(audio_chunk * 32767)
             yield audio_chunk
 