From d31d6654556decd424a2f1203251da7acb25226d Mon Sep 17 00:00:00 2001
From: Andres Marafioti <andimarafioti@gmail.com>
Date: Mon, 19 Aug 2024 17:44:29 +0200
Subject: [PATCH] add mlx lm to make it go BRRRR

---
 LLM/chat.py      | 28 ++++++++++++++++++
 LLM/mlx_lm.py    | 75 ++++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  3 +-
 s2s_pipeline.py  |  9 +++---
 4 files changed, 110 insertions(+), 5 deletions(-)
 create mode 100644 LLM/chat.py
 create mode 100644 LLM/mlx_lm.py

diff --git a/LLM/chat.py b/LLM/chat.py
new file mode 100644
index 0000000..4245830
--- /dev/null
+++ b/LLM/chat.py
@@ -0,0 +1,28 @@
+class Chat:
+    """
+    Handles the chat using to avoid OOM issues.
+    """
+
+    def __init__(self, size):
+        self.size = size
+        self.init_chat_message = None
+        # The buffer grows in pairs: each step appends a user prompt and an assistant answer.
+        self.buffer = []
+
+    def append(self, item):
+        self.buffer.append(item)
+        if len(self.buffer) == 2 * (self.size + 1):
+            # Once the history exceeds `size` pairs, drop the oldest user/assistant pair.
+            self.buffer.pop(0)
+            self.buffer.pop(0)
+
+    def init_chat(self, init_chat_message):
+        self.init_chat_message = init_chat_message
+
+    def to_list(self):
+        if self.init_chat_message:
+            return [self.init_chat_message] + self.buffer
+        else:
+            return self.buffer
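+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the pipeline): shows how the buffer keeps only
+    # the most recent user/assistant pairs on top of the optional init message.
+    chat = Chat(size=1)
+    chat.init_chat({"role": "system", "content": "You are a helpful AI assistant."})
+    chat.append({"role": "user", "content": "Hi!"})
+    chat.append({"role": "assistant", "content": "Hello! How can I help?"})
+    chat.append({"role": "user", "content": "Tell me a joke."})
+    chat.append({"role": "assistant", "content": "Why did the GPU cross the road?"})
+    # Only the latest pair survives alongside the init message; the older pair was dropped.
+    print(chat.to_list())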
diff --git a/LLM/mlx_lm.py b/LLM/mlx_lm.py
new file mode 100644
index 0000000..ff63b71
--- /dev/null
+++ b/LLM/mlx_lm.py
@@ -0,0 +1,75 @@
+import logging
+
+import torch
+from mlx_lm import generate, load, stream_generate
+from rich.console import Console
+
+from baseHandler import BaseHandler
+from LLM.chat import Chat
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+class MLXLanguageModelHandler(BaseHandler):
+    """
+    Handles the language model part.
+    """
+
+    def setup(
+        self,
+        model_name="microsoft/Phi-3-mini-4k-instruct",
+        device="mps",
+        torch_dtype="float16",
+        gen_kwargs={},
+        user_role="user",
+        chat_size=1,
+        init_chat_role=None,
+        init_chat_prompt="You are a helpful AI assistant.",
+    ):
+        self.model_name = model_name
+        # Load weights and tokenizer for the configured model (default: Phi-3-mini-4k-instruct).
+        self.model, self.tokenizer = load(self.model_name)
+        self.gen_kwargs = gen_kwargs
+
+        self.chat = Chat(chat_size)
+        if init_chat_role:
+            if not init_chat_prompt:
+                raise ValueError(
+                    "An initial prompt needs to be specified when setting init_chat_role."
+                )
+            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
+        self.user_role = user_role
+
+        self.warmup()
+
+    def warmup(self):
+        logger.info(f"Warming up {self.__class__.__name__}")
+
+        dummy_input_text = "Write me a poem about Machine Learning."
+        dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
+        
+        n_steps = 2
+
+        for _ in range(n_steps):
+            prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
+            generate(
+                self.model, self.tokenizer, prompt=prompt,
+                max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False,
+            )
+
+
+    def process(self, prompt):
+        logger.debug("inferring language model...")
+
+        self.chat.append({"role": self.user_role, "content": prompt})
+        prompt = self.tokenizer.apply_chat_template(
+            self.chat.to_list(), tokenize=False, add_generation_prompt=True
+        )
+        output = ""
+        curr_output = ""
+        for t in stream_generate(
+            self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]
+        ):
+            output += t
+            curr_output += t
+            # Emit a chunk as soon as it ends like a sentence (or the model emits its end token).
+            if curr_output.endswith(('.', '?', '!', '<|end|>')):
+                yield curr_output.replace('<|end|>', '')
+                curr_output = ""
+        # Flush any trailing text that did not end with a sentence terminator.
+        if curr_output:
+            yield curr_output.replace('<|end|>', '')
+        generated_text = output.replace('<|end|>', '')
+        torch.mps.empty_cache()
+
+        self.chat.append({"role": "assistant", "content": generated_text})
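+
+
+if __name__ == "__main__":
+    # Hedged, model-free sketch: replays the sentence-chunking logic from `process`
+    # on a hard-coded token stream instead of `stream_generate`, so it runs without
+    # downloading Phi-3. The token strings below are made up for illustration.
+    fake_stream = ["Hello", " there", ".", " How", " can", " I", " help", "?<|end|>"]
+    curr_output = ""
+    for t in fake_stream:
+        curr_output += t
+        if curr_output.endswith(('.', '?', '!', '<|end|>')):
+            print(repr(curr_output.replace('<|end|>', '')))  # 'Hello there.' then ' How can I help?'
+            curr_output = ""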
diff --git a/requirements.txt b/requirements.txt
index b928d18..1a83958 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ nltk==3.8.1
 parler_tts @ git+https://github.com/huggingface/parler-tts.git
 torch==2.4.0
 sounddevice==0.5.0
-lightning-whisper-mlx==0.0.10
\ No newline at end of file
+lightning-whisper-mlx==0.0.10
+mlx-lm==0.17.0
\ No newline at end of file
diff --git a/s2s_pipeline.py b/s2s_pipeline.py
index f9692e7..146f175 100644
--- a/s2s_pipeline.py
+++ b/s2s_pipeline.py
@@ -11,6 +11,7 @@ from threading import Event, Thread
 from time import perf_counter
 from typing import Optional
 
+from LLM.mlx_lm import MLXLanguageModelHandler
 from TTS.melotts import MeloTTSHandler
 from baseHandler import BaseHandler
 from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
@@ -484,13 +485,13 @@ class LanguageModelHandlerArguments:
         },
     )
     init_chat_role: str = field(
-        default=None,
+        default='system',
         metadata={
             "help": "Initial role for setting up the chat context. Default is 'system'."
         },
     )
     init_chat_prompt: str = field(
-        default="You are a helpful AI assistant.",
+        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
         metadata={
             "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
         },
@@ -514,7 +515,7 @@ class LanguageModelHandlerArguments:
         },
     )
     chat_size: int = field(
-        default=1,
+        default=2,
         metadata={
             "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
         },
@@ -1015,7 +1016,7 @@ def main():
         queue_out=text_prompt_queue,
         setup_kwargs=vars(whisper_stt_handler_kwargs),
     )
-    lm = LanguageModelHandler(
+    lm = MLXLanguageModelHandler(
         stop_event,
         queue_in=text_prompt_queue,
         queue_out=lm_response_queue,
-- 
GitLab