Skip to content
Snippets Groups Projects
Unverified Commit 3d98e34b authored by Ronan McGovern's avatar Ronan McGovern Committed by GitHub
Browse files

Allow LM selection and Gemma

Allow the language model to be selected via command line.
Remove the system message if using Gemma.
parent cb48b916
No related branches found
No related tags found
No related merge requests found
...@@ -30,8 +30,7 @@ class MLXLanguageModelHandler(BaseHandler): ...@@ -30,8 +30,7 @@ class MLXLanguageModelHandler(BaseHandler):
init_chat_prompt="You are a helpful AI assistant.", init_chat_prompt="You are a helpful AI assistant.",
): ):
self.model_name = model_name self.model_name = model_name
model_id = "microsoft/Phi-3-mini-4k-instruct" self.model, self.tokenizer = load(self.model_name)
self.model, self.tokenizer = load(model_id)
self.gen_kwargs = gen_kwargs self.gen_kwargs = gen_kwargs
self.chat = Chat(chat_size) self.chat = Chat(chat_size)
...@@ -67,8 +66,15 @@ class MLXLanguageModelHandler(BaseHandler): ...@@ -67,8 +66,15 @@ class MLXLanguageModelHandler(BaseHandler):
logger.debug("infering language model...") logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt}) self.chat.append({"role": self.user_role, "content": prompt})
# Remove system messages if using a Gemma model
if "gemma" in self.model_name.lower():
chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"]
else:
chat_messages = self.chat.to_list()
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
self.chat.to_list(), tokenize=False, add_generation_prompt=True chat_messages, tokenize=False, add_generation_prompt=True
) )
output = "" output = ""
curr_output = "" curr_output = ""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment