diff --git a/LLM/mlx_lm.py b/LLM/mlx_lm.py index a772e3a34a727c3656f87bff532f012bc915c43a..0192afb36e0fcfe5a166d7589a9fe6f38e8cfd05 100644 --- a/LLM/mlx_lm.py +++ b/LLM/mlx_lm.py @@ -30,8 +30,7 @@ class MLXLanguageModelHandler(BaseHandler): init_chat_prompt="You are a helpful AI assistant.", ): self.model_name = model_name - model_id = "microsoft/Phi-3-mini-4k-instruct" - self.model, self.tokenizer = load(model_id) + self.model, self.tokenizer = load(self.model_name) self.gen_kwargs = gen_kwargs self.chat = Chat(chat_size) @@ -67,8 +66,15 @@ class MLXLanguageModelHandler(BaseHandler): logger.debug("infering language model...") self.chat.append({"role": self.user_role, "content": prompt}) + + # Remove system messages if using a Gemma model + if "gemma" in self.model_name.lower(): + chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"] + else: + chat_messages = self.chat.to_list() + prompt = self.tokenizer.apply_chat_template( - self.chat.to_list(), tokenize=False, add_generation_prompt=True + chat_messages, tokenize=False, add_generation_prompt=True ) output = "" curr_output = ""