Skip to content
Snippets Groups Projects
Unverified Commit d3d25c45 authored by Andrés Marafioti's avatar Andrés Marafioti Committed by GitHub
Browse files

Merge pull request #43 from huggingface/refactor-handlers

refactor all the handlers - folder structure
parents f72806da d50687a0
No related branches found
No related tags found
No related merge requests found
from threading import Thread
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
TextIteratorStreamer,
)
import torch
from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
from nltk import sent_tokenize
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class LanguageModelHandler(BaseHandler):
    """
    Handles the language model part of the pipeline: feeds the rolling chat
    history to a causal LM and streams the generated answer back to the next
    handler, sentence by sentence.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="cuda",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        """
        Load the tokenizer/model, build the streaming text-generation pipeline
        and optionally seed the chat with an initial (system) message.

        Args:
            model_name: HF hub id of the causal LM.
            device: torch device string ("cuda", "mps", "cpu").
            torch_dtype: name of a torch dtype attribute (e.g. "float16").
            gen_kwargs: extra kwargs forwarded to the generation pipeline.
                Must contain "max_new_tokens" (used by warmup()).
            user_role: role name used for user messages in the chat.
            chat_size: number of interactions kept in the rolling chat buffer.
            init_chat_role: optional role of the initial message (e.g. "system").
            init_chat_prompt: content of the initial message when
                init_chat_role is set.
        """
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Fix: pass the resolved torch dtype object (self.torch_dtype) rather
        # than the raw string, so the weights are loaded in the requested
        # precision consistently with the rest of this class.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=self.torch_dtype, trust_remote_code=True
        ).to(device)
        self.pipe = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )
        # skip_prompt / skip_special_tokens: only newly generated, clean text
        # is streamed out.
        self.streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        self.gen_kwargs = {
            "streamer": self.streamer,
            "return_full_text": False,
            **gen_kwargs,
        }

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        """Run two dummy generations so CUDA kernels/graphs are ready before
        the first real request; logs elapsed time on CUDA devices."""
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
        # Fix: spread self.gen_kwargs FIRST so the forced min/max_new_tokens
        # below actually take effect. Previously they were listed before
        # **self.gen_kwargs and were silently overridden by it, defeating the
        # purpose of forcing a fixed-length warmup generation.
        warmup_gen_kwargs = {
            **self.gen_kwargs,
            "min_new_tokens": self.gen_kwargs["max_new_tokens"],
            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
        }

        n_steps = 2

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            thread = Thread(
                target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
            )
            thread.start()
            # Drain the streamer so the generation thread can run to completion.
            for _ in self.streamer:
                pass

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, prompt):
        """Append the user prompt to the chat, run generation in a background
        thread, and yield complete sentences as they become available."""
        logger.debug("infering language model...")

        self.chat.append({"role": self.user_role, "content": prompt})
        # Generation runs in a worker thread; tokens arrive through
        # self.streamer, which is consumed below.
        thread = Thread(
            target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
        )
        thread.start()
        if self.device == "mps":
            # On Apple Silicon the whole answer is accumulated and emitted in
            # one piece (single final yield) instead of sentence streaming.
            generated_text = ""
            for new_text in self.streamer:
                generated_text += new_text
            printable_text = generated_text
            torch.mps.empty_cache()
        else:
            generated_text, printable_text = "", ""
            for new_text in self.streamer:
                generated_text += new_text
                printable_text += new_text
                sentences = sent_tokenize(printable_text)
                if len(sentences) > 1:
                    yield (sentences[0])
                    # NOTE(review): resetting to new_text (rather than the
                    # untokenized remainder) can drop a few characters when one
                    # chunk both closes a sentence and opens the next — kept
                    # as-is to preserve existing behavior.
                    printable_text = new_text

        self.chat.append({"role": "assistant", "content": generated_text})

        # don't forget last sentence
        yield printable_text
...@@ -66,13 +66,15 @@ class MLXLanguageModelHandler(BaseHandler): ...@@ -66,13 +66,15 @@ class MLXLanguageModelHandler(BaseHandler):
logger.debug("infering language model...") logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt}) self.chat.append({"role": self.user_role, "content": prompt})
# Remove system messages if using a Gemma model # Remove system messages if using a Gemma model
if "gemma" in self.model_name.lower(): if "gemma" in self.model_name.lower():
chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"] chat_messages = [
msg for msg in self.chat.to_list() if msg["role"] != "system"
]
else: else:
chat_messages = self.chat.to_list() chat_messages = self.chat.to_list()
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
chat_messages, tokenize=False, add_generation_prompt=True chat_messages, tokenize=False, add_generation_prompt=True
) )
......
from time import perf_counter
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
)
import torch
from baseHandler import BaseHandler
from rich.console import Console
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class WhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
    ):
        """
        Load the Whisper checkpoint and optionally torch.compile its forward.

        Args:
            model_name: HF hub id of the (distil-)Whisper model.
            device: torch device string.
            torch_dtype: name of a torch dtype attribute (e.g. "float16").
            compile_mode: None, "default", or a CUDA-graph-capturing mode
                (e.g. "reduce-overhead"). Non-default modes require
                gen_kwargs["max_new_tokens"] to be set (used during warmup).
            gen_kwargs: kwargs forwarded to model.generate().
        """
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        self.gen_kwargs = gen_kwargs

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)

        # compile: a static KV cache is required for fullgraph compilation of
        # the generation loop
        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )
        self.warmup()

    def prepare_model_inputs(self, spoken_prompt):
        """Convert a raw 16 kHz waveform into mel input features placed on the
        target device/dtype."""
        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        return input_features

    def warmup(self):
        """Run dummy generations so kernels (and, for compile modes, CUDA
        graphs) are captured before serving real audio."""
        logger.info(f"Warming up {self.__class__.__name__}")

        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
        n_steps = 1 if self.compile_mode == "default" else 2
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device,
        )
        if self.compile_mode not in (None, "default"):
            # generating more tokens than previously will trigger CUDA graphs capture
            # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
            # Fix: spread self.gen_kwargs FIRST so the forced token counts are
            # not silently overridden by caller-provided min/max_new_tokens
            # (the previous ordering let gen_kwargs win).
            warmup_gen_kwargs = {
                **self.gen_kwargs,
                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            }
        else:
            warmup_gen_kwargs = self.gen_kwargs

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, spoken_prompt):
        """Transcribe one utterance; yields the decoded text once."""
        logger.debug("infering whisper...")

        # NOTE(review): module-level global recording the pipeline start time
        # for latency measurement; only visible to code in this same module.
        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        pred_text = self.processor.batch_decode(
            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
        )[0]

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text
...@@ -24,7 +24,6 @@ class MeloTTSHandler(BaseHandler): ...@@ -24,7 +24,6 @@ class MeloTTSHandler(BaseHandler):
gen_kwargs={}, # Unused gen_kwargs={}, # Unused
blocksize=512, blocksize=512,
): ):
print(device)
self.should_listen = should_listen self.should_listen = should_listen
self.device = device self.device = device
self.model = TTS(language=language, device=device) self.model = TTS(language=language, device=device)
......
from threading import Thread
from time import perf_counter
from baseHandler import BaseHandler
import numpy as np
import torch
from transformers import (
AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
import librosa
import logging
from rich.console import Console
from utils.utils import next_power_of_2
torch._inductor.config.fx_graph_cache = True
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
torch._dynamo.config.cache_size_limit = 15
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class ParlerTTSHandler(BaseHandler):
    """
    Text-to-speech handler built on Parler-TTS: turns each LLM sentence into
    a stream of fixed-size int16 audio blocks at 16 kHz.
    """

    def setup(
        self,
        should_listen,
        model_name="ylacombe/parler-tts-mini-jenny-30H",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
        max_prompt_pad_length=8,
        description=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        play_steps_s=1,
        blocksize=512,
    ):
        """
        Load the Parler-TTS model and (optionally) compile its forward.

        Args:
            should_listen: event set when playback ends so the receiver
                resumes forwarding microphone audio.
            model_name: HF hub id of the Parler-TTS checkpoint.
            device: torch device string.
            torch_dtype: name of a torch dtype attribute (e.g. "float16").
            compile_mode: None, "default", or a CUDA-graph mode (the latter
                is downgraded to "default" below).
            gen_kwargs: extra kwargs forwarded to model.generate().
            max_prompt_pad_length: exponent bound for the power-of-two prompt
                pad lengths warmed up when compiling.
            description: fixed voice/style description fed to the model.
            play_steps_s: seconds of audio per streamed chunk.
            blocksize: number of samples per yielded audio block.
        """
        self.should_listen = should_listen
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.gen_kwargs = gen_kwargs
        self.compile_mode = compile_mode
        self.max_prompt_pad_length = max_prompt_pad_length
        self.description = description

        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=self.torch_dtype
        ).to(device)

        # number of decoder steps corresponding to play_steps_s of audio
        framerate = self.model.audio_encoder.config.frame_rate
        self.play_steps = int(framerate * play_steps_s)
        self.blocksize = blocksize

        if self.compile_mode not in (None, "default"):
            logger.warning(
                "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
            )
            self.compile_mode = "default"

        if self.compile_mode:
            # static cache required for fullgraph compilation
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )

        self.warmup()

    def prepare_model_inputs(
        self,
        prompt,
        max_length_prompt=50,
        pad=False,
    ):
        """Tokenize the fixed description and the text prompt; optionally pad
        the prompt to max_length_prompt so compiled graphs can be reused.
        Returns the full kwargs dict for model.generate()."""
        pad_args_prompt = (
            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
        )

        tokenized_description = self.description_tokenizer(
            self.description, return_tensors="pt"
        )
        input_ids = tokenized_description.input_ids.to(self.device)
        attention_mask = tokenized_description.attention_mask.to(self.device)

        tokenized_prompt = self.prompt_tokenizer(
            prompt, return_tensors="pt", **pad_args_prompt
        )
        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)

        gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
            **self.gen_kwargs,
        }

        return gen_kwargs

    def warmup(self):
        """Pre-run generation (once per padded prompt length when compiling)
        so kernels are compiled before real sentences arrive."""
        logger.info(f"Warming up {self.__class__.__name__}")

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
        n_steps = 1 if self.compile_mode == "default" else 2

        if self.device == "cuda":
            torch.cuda.synchronize()
            start_event.record()
        if self.compile_mode:
            # warm up every power-of-two pad length the runtime may use,
            # longest first (lengths 4 .. 2**(max_prompt_pad_length-1))
            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
            for pad_length in pad_lengths[::-1]:
                model_kwargs = self.prepare_model_inputs(
                    "dummy prompt", max_length_prompt=pad_length, pad=True
                )
                for _ in range(n_steps):
                    _ = self.model.generate(**model_kwargs)
                logger.info(f"Warmed up length {pad_length} tokens!")
        else:
            model_kwargs = self.prepare_model_inputs("dummy prompt")
            for _ in range(n_steps):
                _ = self.model.generate(**model_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, llm_sentence):
        """Synthesize one sentence; yields int16 blocks of self.blocksize
        samples at 16 kHz (last block zero-padded), then re-enables listening."""
        console.print(f"[green]ASSISTANT: {llm_sentence}")
        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)

        pad_args = {}
        if self.compile_mode:
            # pad to closest upper power of two
            pad_length = next_power_of_2(nb_tokens)
            logger.debug(f"padding to {pad_length}")
            pad_args["pad"] = True
            pad_args["max_length_prompt"] = pad_length

        tts_gen_kwargs = self.prepare_model_inputs(
            llm_sentence,
            **pad_args,
        )

        streamer = ParlerTTSStreamer(
            self.model, device=self.device, play_steps=self.play_steps
        )
        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
        # fixed seed for reproducible voice output across sentences
        torch.manual_seed(0)
        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
        thread.start()

        for i, audio_chunk in enumerate(streamer):
            # NOTE(review): pipeline_start is set by the STT handler; since the
            # refactor split the handlers into separate modules, this module's
            # globals() may never contain it — the latency log would then be
            # skipped silently. Verify against the STT module.
            global pipeline_start
            if i == 0 and "pipeline_start" in globals():
                logger.info(
                    f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
                )
            # assumes the model emits audio at 44.1 kHz — TODO confirm against
            # self.model.audio_encoder.config.sampling_rate
            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
            # float [-1, 1] -> int16 PCM
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            for i in range(0, len(audio_chunk), self.blocksize):
                yield np.pad(
                    audio_chunk[i : i + self.blocksize],
                    (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
                )

        self.should_listen.set()
from VAD.vad_iterator import VADIterator
from baseHandler import BaseHandler
import numpy as np
import torch
from rich.console import Console
from utils.utils import int2float
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class VADHandler(BaseHandler):
    """
    Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
    to the following part.
    """

    def setup(
        self,
        should_listen,
        thresh=0.3,
        sample_rate=16000,
        min_silence_ms=1000,
        min_speech_ms=500,
        max_speech_ms=float("inf"),
        speech_pad_ms=30,
    ):
        """Load the Silero VAD model and configure the streaming iterator."""
        # duration bounds applied to each detected utterance
        self.min_speech_ms = min_speech_ms
        self.max_speech_ms = max_speech_ms
        self.min_silence_ms = min_silence_ms
        self.sample_rate = sample_rate
        self.should_listen = should_listen

        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
        self.iterator = VADIterator(
            self.model,
            threshold=thresh,
            sampling_rate=sample_rate,
            min_silence_duration_ms=min_silence_ms,
            speech_pad_ms=speech_pad_ms,
        )

    def process(self, audio_chunk):
        """Feed one raw int16 chunk to the VAD; yield the full utterance as a
        float32 numpy array once end-of-speech is detected and the duration is
        within bounds."""
        samples = int2float(np.frombuffer(audio_chunk, dtype=np.int16))
        speech = self.iterator(torch.from_numpy(samples))
        if speech is None or len(speech) == 0:
            # still inside silence or an ongoing utterance
            return

        logger.debug("VAD: end of speech detected")
        array = torch.cat(speech).cpu().numpy()
        duration_ms = len(array) / self.sample_rate * 1000
        if self.min_speech_ms <= duration_ms <= self.max_speech_ms:
            # pause mic forwarding while the rest of the pipeline answers
            self.should_listen.clear()
            logger.debug("Stop listening")
            yield array
        else:
            logger.debug(
                f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
            )
import numpy as np
import torch import torch
def next_power_of_2(x):
    """Smallest power of two that is >= x; defined as 1 when x == 0."""
    return 2 ** (x - 1).bit_length() if x else 1
def int2float(sound):
    """Convert int16 PCM samples to float32 roughly in [-1, 1].

    Taken from https://github.com/snakers4/silero-vad
    """
    out = sound.astype("float32")
    # avoid dividing a silent buffer (all zeros stays all zeros)
    if np.abs(sound).max() > 0:
        out = out * (1 / 32768)
    out = out.squeeze()  # depends on the use case
    return out
class VADIterator: class VADIterator:
def __init__( def __init__(
self, self,
......
File moved
import socket
from rich.console import Console
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class SocketReceiver:
    """
    Handles reception of the audio packets from the client.
    """

    def __init__(
        self,
        stop_event,
        queue_out,
        should_listen,
        host="0.0.0.0",
        port=12345,
        chunk_size=1024,
    ):
        """Store configuration; the listening socket is opened in run()."""
        self.host = host
        self.port = port
        self.chunk_size = chunk_size
        self.stop_event = stop_event
        self.queue_out = queue_out
        self.should_listen = should_listen

    def receive_full_chunk(self, conn, chunk_size):
        """Block until exactly chunk_size bytes are read from conn; return
        None if the peer closes the connection first."""
        buffer = bytearray()
        while len(buffer) < chunk_size:
            packet = conn.recv(chunk_size - len(buffer))
            if not packet:
                # connection closed
                return None
            buffer.extend(packet)
        return bytes(buffer)

    def run(self):
        """Accept a single client and forward fixed-size audio chunks to
        queue_out until the stop event fires or the peer disconnects."""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((self.host, self.port))
        self.socket.listen(1)
        logger.info("Receiver waiting to be connected...")
        self.conn, _ = self.socket.accept()
        logger.info("receiver connected")

        self.should_listen.set()
        while not self.stop_event.is_set():
            chunk = self.receive_full_chunk(self.conn, self.chunk_size)
            if chunk is None:
                # connection closed
                self.queue_out.put(b"END")
                break
            # drop audio while the assistant is speaking
            if self.should_listen.is_set():
                self.queue_out.put(chunk)
        self.conn.close()
        logger.info("Receiver closed")
import socket
from rich.console import Console
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class SocketSender:
    """
    Handles sending generated audio packets to the clients.
    """

    def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
        """Store configuration; the listening socket is opened in run()."""
        self.host = host
        self.port = port
        self.stop_event = stop_event
        self.queue_in = queue_in

    def run(self):
        """Accept a single client and stream every queued chunk to it; the
        b"END" sentinel is forwarded to the client and then ends the loop."""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((self.host, self.port))
        self.socket.listen(1)
        logger.info("Sender waiting to be connected...")
        self.conn, _ = self.socket.accept()
        logger.info("sender connected")

        while not self.stop_event.is_set():
            chunk = self.queue_in.get()
            self.conn.sendall(chunk)
            if isinstance(chunk, bytes) and chunk == b"END":
                break
        self.conn.close()
        logger.info("Sender closed")
This diff is collapsed.
import threading
class ThreadManager:
    """
    Manages multiple threads used to execute given handler tasks.
    """

    def __init__(self, handlers):
        self.handlers = handlers
        self.threads = []

    def start(self):
        """Spawn one worker thread per handler's run() method."""
        for handler in self.handlers:
            worker = threading.Thread(target=handler.run)
            self.threads.append(worker)
            worker.start()

    def stop(self):
        """Signal every handler to stop, then wait for all workers to exit."""
        for handler in self.handlers:
            handler.stop_event.set()
        for worker in self.threads:
            worker.join()
import numpy as np
def next_power_of_2(x):
    """Return the smallest power of two >= x (1 for x == 0)."""
    if x == 0:
        return 1
    return 1 << (x - 1).bit_length()
def int2float(sound):
    """Scale int16 PCM audio into float32 samples roughly in [-1, 1].

    Taken from https://github.com/snakers4/silero-vad
    """
    peak = np.abs(sound).max()
    converted = sound.astype("float32")
    # keep silent buffers untouched (no division of all-zero audio)
    if peak > 0:
        converted *= 1 / 32768
    return converted.squeeze()  # depends on the use case
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment