From d31d6654556decd424a2f1203251da7acb25226d Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Mon, 19 Aug 2024 17:44:29 +0200
Subject: [PATCH] add mlx lm to make it go BRRRR

---
 LLM/chat.py      | 28 ++++++++++++++++++
 LLM/mlx_lm.py    | 80 ++++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  3 +-
 s2s_pipeline.py  |  9 +++---
 4 files changed, 115 insertions(+), 5 deletions(-)
 create mode 100644 LLM/chat.py
 create mode 100644 LLM/mlx_lm.py

diff --git a/LLM/chat.py b/LLM/chat.py
new file mode 100644
index 0000000..4245830
--- /dev/null
+++ b/LLM/chat.py
@@ -0,0 +1,28 @@
+
+
+
+class Chat:
+    """
+    Handles the chat history, keeping it to a fixed size to avoid OOM issues.
+    """
+
+    def __init__(self, size):
+        self.size = size
+        self.init_chat_message = None
+        # the buffer holds message pairs: each step appends a user prompt and an assistant answer
+        self.buffer = []
+
+    def append(self, item):
+        self.buffer.append(item)
+        if len(self.buffer) == 2 * (self.size + 1):
+            self.buffer.pop(0)
+            self.buffer.pop(0)
+
+    def init_chat(self, init_chat_message):
+        self.init_chat_message = init_chat_message
+
+    def to_list(self):
+        if self.init_chat_message:
+            return [self.init_chat_message] + self.buffer
+        else:
+            return self.buffer
diff --git a/LLM/mlx_lm.py b/LLM/mlx_lm.py
new file mode 100644
index 0000000..ff63b71
--- /dev/null
+++ b/LLM/mlx_lm.py
@@ -0,0 +1,80 @@
+import logging
+from LLM.chat import Chat
+from baseHandler import BaseHandler
+from mlx_lm import load, stream_generate, generate
+from rich.console import Console
+import torch
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+class MLXLanguageModelHandler(BaseHandler):
+    """
+    Handles the language model part of the pipeline using mlx-lm.
+    """
+
+    def setup(
+        self,
+        model_name="microsoft/Phi-3-mini-4k-instruct",
+        device="mps",
+        torch_dtype="float16",
+        gen_kwargs={},
+        user_role="user",
+        chat_size=1,
+        init_chat_role=None,
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.model_name = model_name
+        # load the model weights and tokenizer via mlx-lm
+        self.model, self.tokenizer = load(self.model_name)
+        self.gen_kwargs = gen_kwargs
+
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial prompt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
+
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        dummy_input_text = "Write me a poem about Machine Learning."
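+        # warm up with the same chat-template + generate path as a real request,
+        # so the first user turn does not pay one-time lazy-compilation costs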
+        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
+
+        n_steps = 2
+
+        for _ in range(n_steps):
+            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
+            generate(self.model, self.tokenizer, prompt=prompt, max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False)
+
+
+    def process(self, prompt):
+        logger.debug("inferring with the language model...")
+
+        self.chat.append({"role": self.user_role, "content": prompt})
+        prompt = self.tokenizer.apply_chat_template(self.chat.to_list(), tokenize=False, add_generation_prompt=True)
+        output = ""
+        curr_output = ""
+        for t in stream_generate(self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]):
+            output += t
+            curr_output += t
+            # yield complete sentences as they stream in, so TTS can start early
+            if curr_output.endswith(('.', '?', '!', '<|end|>')):
+                yield curr_output.replace('<|end|>', '')
+                curr_output = ""
+        if curr_output:  # flush trailing text that lacks sentence-final punctuation
+            yield curr_output.replace('<|end|>', '')
+        generated_text = output.replace('<|end|>', '')
+        torch.mps.empty_cache()
+
+        self.chat.append({"role": "assistant", "content": generated_text})
diff --git a/requirements.txt b/requirements.txt
index b928d18..1a83958 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ nltk==3.8.1
 parler_tts @ git+https://github.com/huggingface/parler-tts.git
 torch==2.4.0
 sounddevice==0.5.0
-lightning-whisper-mlx==0.0.10
\ No newline at end of file
+lightning-whisper-mlx==0.0.10
+mlx-lm==0.17.0
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index f9692e7..146f175 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -11,6 +11,7 @@ from threading import Event, Thread
 from time import perf_counter
 from typing import Optional
 
+from LLM.mlx_lm import MLXLanguageModelHandler
 from TTS.melotts import MeloTTSHandler
 from baseHandler import BaseHandler
 from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
@@ -484,13 +485,13 @@ class LanguageModelHandlerArguments:
         },
     )
     init_chat_role: str = field(
-        default=None,
+        default='system',
         metadata={
             "help": "Initial role for setting up the chat context. Default is 'system'."
         },
     )
     init_chat_prompt: str = field(
-        default="You are a helpful AI assistant.",
+        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
             "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
         },
@@ -514,7 +515,7 @@ class LanguageModelHandlerArguments:
         },
     )
     chat_size: int = field(
-        default=1,
+        default=2,
         metadata={
             "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
         },
@@ -1015,7 +1016,7 @@ def main():
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    lm = LanguageModelHandler(
+    lm = MLXLanguageModelHandler(
         stop_event,
         queue_in=text_prompt_queue,
         queue_out=lm_response_queue,
--
GitLab
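Usage sketch (reviewer aid, not part of the patch): the snippet below shows how the new handler is driven once wired into main(). It assumes baseHandler.BaseHandler takes (stop_event, queue_in, queue_out, setup_kwargs) and exposes the run() loop used by the pipeline's other handlers; the queue names, thread wiring, and gen_kwargs values are illustrative only. Note that gen_kwargs must carry "max_new_tokens", since both warmup() and process() read that key.

    from queue import Queue
    from threading import Event, Thread

    from LLM.mlx_lm import MLXLanguageModelHandler

    stop_event = Event()
    text_prompt_queue = Queue()   # transcribed user turns, filled by the STT handler
    lm_response_queue = Queue()   # complete sentences, drained by the TTS handler

    lm = MLXLanguageModelHandler(
        stop_event,
        queue_in=text_prompt_queue,
        queue_out=lm_response_queue,
        setup_kwargs={
            "model_name": "microsoft/Phi-3-mini-4k-instruct",
            "gen_kwargs": {"max_new_tokens": 128},  # required by warmup() and process()
            "init_chat_role": "system",
            "chat_size": 2,
        },
    )

    Thread(target=lm.run, daemon=True).start()  # run() is assumed from BaseHandler
    text_prompt_queue.put("Tell me a fun fact about whales.")
    print(lm_response_queue.get())  # first complete sentence, ready for TTS

Because process() yields on sentence boundaries ('.', '?', '!', or the Phi-3 '<|end|>' token), the TTS handler can start speaking the first sentence while the rest of the answer is still being generated.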