Commit d31d6654 authored by Andres Marafioti

add mlx lm to make it go BRRRR

parent ea2d73cc
LLM/chat.py

class Chat:
    """
    Keeps a bounded chat history to avoid OOM issues.
    """

    def __init__(self, size):
        self.size = size
        self.init_chat_message = None
        # The buffer length must stay even: each step appends a user prompt and an assistant answer.
        self.buffer = []

    def append(self, item):
        self.buffer.append(item)
        if len(self.buffer) == 2 * (self.size + 1):
            # Drop the oldest user/assistant pair once the limit is exceeded.
            self.buffer.pop(0)
            self.buffer.pop(0)

    def init_chat(self, init_chat_message):
        self.init_chat_message = init_chat_message

    def to_list(self):
        if self.init_chat_message:
            return [self.init_chat_message] + self.buffer
        else:
            return self.buffer
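A minimal usage sketch of this buffer (the messages below are illustrative, not part of the commit), assuming a size of 1:

    # Sketch: exercising Chat directly with size=1.
    chat = Chat(size=1)
    chat.init_chat({"role": "system", "content": "You are a helpful AI assistant."})

    chat.append({"role": "user", "content": "Hi"})
    chat.append({"role": "assistant", "content": "Hello!"})
    chat.append({"role": "user", "content": "How are you?"})
    chat.append({"role": "assistant", "content": "Great."})
    # The fourth append hits 2 * (size + 1) = 4 buffered messages, so the oldest
    # user/assistant pair is dropped and only the most recent pair remains.
    print(chat.to_list())
    # -> [system message, {"role": "user", ...: "How are you?"}, {"role": "assistant", ...: "Great."}]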
LLM/mlx_lm.py

import logging

from LLM.chat import Chat
from baseHandler import BaseHandler
from mlx_lm import load, stream_generate, generate
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class MLXLanguageModelHandler(BaseHandler):
    """
    Handles the language model part.
    """

    def setup(
        self,
        model_name="microsoft/Phi-3-mini-4k-instruct",
        device="mps",
        torch_dtype="float16",
        gen_kwargs={},
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        self.model_name = model_name
        self.model, self.tokenizer = load(self.model_name)
        self.gen_kwargs = gen_kwargs

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        dummy_input_text = "Write me a poem about Machine Learning."
        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]

        n_steps = 2

        for _ in range(n_steps):
            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
            generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                max_tokens=self.gen_kwargs["max_new_tokens"],
                verbose=False,
            )

    def process(self, prompt):
        logger.debug("inferring language model...")

        self.chat.append({"role": self.user_role, "content": prompt})
        prompt = self.tokenizer.apply_chat_template(
            self.chat.to_list(), tokenize=False, add_generation_prompt=True
        )
        output = ""
        curr_output = ""
        for t in stream_generate(
            self.model,
            self.tokenizer,
            prompt,
            max_tokens=self.gen_kwargs["max_new_tokens"],
        ):
            output += t
            curr_output += t
            # Yield a chunk as soon as it ends on a sentence boundary (or the model's
            # end token), so TTS can start before the full answer is generated.
            if curr_output.endswith((".", "?", "!", "<|end|>")):
                yield curr_output.replace("<|end|>", "")
                curr_output = ""
        generated_text = output.replace("<|end|>", "")
        torch.mps.empty_cache()

        self.chat.append({"role": "assistant", "content": generated_text})
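The handler streams the answer and flushes each complete sentence so downstream TTS can start early. A rough standalone sketch of the same pattern using mlx_lm directly (the prompt and token limit are illustrative; it assumes mlx-lm 0.17.x, where stream_generate yields plain text fragments exactly as the handler above expects):

    # Sketch: sentence-boundary streaming with mlx_lm, outside the handler class.
    from mlx_lm import load, stream_generate

    model, tokenizer = load("microsoft/Phi-3-mini-4k-instruct")
    messages = [{"role": "user", "content": "Explain what MLX is in two sentences."}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    chunk = ""
    for piece in stream_generate(model, tokenizer, prompt, max_tokens=128):
        chunk += piece
        # Flush each complete sentence so a downstream consumer (e.g. TTS) can start early.
        if chunk.endswith((".", "?", "!", "<|end|>")):
            print(chunk.replace("<|end|>", ""))
            chunk = ""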
requirements.txt

@@ -2,4 +2,5 @@ nltk==3.8.1
 parler_tts @ git+https://github.com/huggingface/parler-tts.git
 torch==2.4.0
 sounddevice==0.5.0
-lightning-whisper-mlx==0.0.10
\ No newline at end of file
+lightning-whisper-mlx==0.0.10
+mlx-lm==0.17.0
\ No newline at end of file
@@ -11,6 +11,7 @@ from threading import Event, Thread
 from time import perf_counter
 from typing import Optional

+from LLM.mlx_lm import MLXLanguageModelHandler
 from TTS.melotts import MeloTTSHandler
 from baseHandler import BaseHandler
 from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
@@ -484,13 +485,13 @@ class LanguageModelHandlerArguments:
         },
     )
     init_chat_role: str = field(
-        default=None,
+        default='system',
         metadata={
             "help": "Initial role for setting up the chat context. Default is 'system'."
         },
     )
     init_chat_prompt: str = field(
-        default="You are a helpful AI assistant.",
+        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
         metadata={
             "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
         },
@@ -514,7 +515,7 @@ class LanguageModelHandlerArguments:
         },
     )
     chat_size: int = field(
-        default=1,
+        default=2,
         metadata={
             "help": "Number of assistant-user interactions to keep for the chat. None for no limitation."
         },
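To make the new defaults concrete, a small sketch of the history the handler would hand to apply_chat_template with init_chat_role='system', the longer system prompt, and chat_size=2 (the user/assistant turns are illustrative, not from the commit):

    # Sketch: the Chat buffer under the new defaults.
    chat = Chat(2)
    chat.init_chat({
        "role": "system",
        "content": "You are a helpful and friendly AI assistant. "
                   "You are polite, respectful, and aim to provide concise responses of less than 20 words.",
    })
    chat.append({"role": "user", "content": "What's the weather like?"})    # illustrative turn
    chat.append({"role": "assistant", "content": "Sunny and mild today."})  # illustrative turn

    # to_list() is what reaches tokenizer.apply_chat_template(...): the system message
    # followed by roughly the two most recent user/assistant pairs.
    print(chat.to_list())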
@@ -1015,7 +1016,7 @@ def main():
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    lm = LanguageModelHandler(
+    lm = MLXLanguageModelHandler(
         stop_event,
         queue_in=text_prompt_queue,
         queue_out=lm_response_queue,
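The pipeline hands text between handlers through queues; the hunk above only swaps which handler sits between text_prompt_queue and lm_response_queue. A bare-bones sketch of that producer/consumer handoff with standard-library queues (the real BaseHandler run loop is not part of this commit, so the worker below is a stand-in):

    # Sketch: the queue handoff pattern the pipeline relies on, independent of BaseHandler.
    from queue import Queue
    from threading import Event, Thread

    stop_event = Event()
    text_prompt_queue = Queue()   # STT output -> LM input
    lm_response_queue = Queue()   # LM output -> TTS input

    def fake_lm_worker():
        # Stand-in for the handler's run loop: read a prompt, emit sentence-sized chunks.
        while not stop_event.is_set():
            prompt = text_prompt_queue.get()
            if prompt is None:
                break
            for chunk in ("Hello!", "How can I help?"):  # illustrative output
                lm_response_queue.put(chunk)

    worker = Thread(target=fake_lm_worker)
    worker.start()
    text_prompt_queue.put("Hi there")
    print(lm_response_queue.get(), lm_response_queue.get())
    text_prompt_queue.put(None)
    stop_event.set()
    worker.join()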