From 3d98e34b4c7335dbd48afbe76eed6f4ae4cfad24 Mon Sep 17 00:00:00 2001
From: Ronan McGovern <78278410+RonanKMcGovern@users.noreply.github.com>
Date: Fri, 23 Aug 2024 11:57:59 +0100
Subject: [PATCH] Allow LM selection and add Gemma support

Allow the language model to be selected via the command line instead of
hard-coding the Phi-3 checkpoint in the handler.
Remove system messages when a Gemma model is used, since Gemma chat
templates do not accept the "system" role.
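
A minimal sketch of the same idea outside the handler (the model id
below is illustrative only, not part of this change):

    from mlx_lm import load

    # Any MLX-converted Gemma checkpoint works here; this id is an example.
    model_name = "mlx-community/gemma-2-2b-it-4bit"
    model, tokenizer = load(model_name)

    chat = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Hello!"},
    ]

    # Gemma chat templates reject the "system" role, so drop those
    # messages before building the prompt.
    if "gemma" in model_name.lower():
        chat = [m for m in chat if m["role"] != "system"]

    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )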
---
 LLM/mlx_lm.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/LLM/mlx_lm.py b/LLM/mlx_lm.py
index a772e3a..0192afb 100644
--- a/LLM/mlx_lm.py
+++ b/LLM/mlx_lm.py
@@ -30,8 +30,7 @@ class MLXLanguageModelHandler(BaseHandler):
         init_chat_prompt="You are a helpful AI assistant.",
     ):
         self.model_name = model_name
-        model_id = "microsoft/Phi-3-mini-4k-instruct"
-        self.model, self.tokenizer = load(model_id)
+        self.model, self.tokenizer = load(self.model_name)
         self.gen_kwargs = gen_kwargs
 
         self.chat = Chat(chat_size)
@@ -67,8 +66,15 @@ class MLXLanguageModelHandler(BaseHandler):
         logger.debug("infering language model...")
 
         self.chat.append({"role": self.user_role, "content": prompt})
+
+        # Gemma chat templates do not accept the "system" role, so drop system messages
+        if "gemma" in self.model_name.lower():
+            chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"]
+        else:
+            chat_messages = self.chat.to_list()
+
         prompt = self.tokenizer.apply_chat_template(
-            self.chat.to_list(), tokenize=False, add_generation_prompt=True
+            chat_messages, tokenize=False, add_generation_prompt=True
         )
         output = ""
         curr_output = ""
-- 
GitLab