From ff5882f00665fedcecacb8da13ead4d73079d9e4 Mon Sep 17 00:00:00 2001
From: bys0318 <bys0318@126.com>
Date: Thu, 2 Nov 2023 20:24:36 +0800
Subject: [PATCH] fix chatglm3 build_chat and prompt length counting

---
 pred.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

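Reviewer note (below the diffstat, so git am ignores it): chatglm3's
tokenizer.build_chat_input() returns a BatchEncoding of token ids rather
than a string, and those ids begin with the two prefix special tokens
[gMASK] and sop. The rest of pred.py expects build_chat() to return plain
text, so the fix decodes the ids back to a string and drops that two-token
prefix. A minimal sketch of the round-trip, assuming the THUDM/chatglm3-6b
tokenizer loaded with trust_remote_code (the checkpoint name is
illustrative, not part of this patch):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm3-6b", trust_remote_code=True)
    ids = tokenizer.build_chat_input("What is LongBench?").input_ids[0]
    # ids[:2] is the [gMASK]/sop prefix; the downstream tokenizer() call
    # in get_pred adds it back, so drop it here to avoid duplication
    text = tokenizer.decode(ids[2:], skip_special_tokens=True)
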
diff --git a/pred.py b/pred.py
index 8163891..e3ff006 100644
--- a/pred.py
+++ b/pred.py
@@ -18,7 +18,9 @@ def parse_args(args=None):
 # This is the customized building prompt for chat models
 def build_chat(tokenizer, prompt, model_name):
     if "chatglm3" in model_name:
-        prompt = tokenizer.build_chat_input(prompt)
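+        # decode build_chat_input's ids back to text, minus the [gMASK]/sop prefix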
+        prompt = tokenizer.build_chat_input(prompt).input_ids[0]
+        prompt = tokenizer.decode(prompt[2:], skip_special_tokens=True)
     elif "chatglm" in model_name:
         prompt = tokenizer.build_prompt(prompt)
     elif "longchat" in model_name or "vicuna" in model_name:
@@ -26,7 +28,7 @@ def build_chat(tokenizer, prompt, model_name):
         conv = get_conversation_template("vicuna")
         conv.append_message(conv.roles[0], prompt)
         conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()        
+        prompt = conv.get_prompt()
     elif "llama2" in model_name:
         prompt = f"[INST]{prompt}[/INST]"
     elif "xgen" in model_name:
@@ -52,6 +54,9 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
         prompt = prompt_format.format(**json_obj)
         # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
         tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
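+        # chatglm3 prepends [gMASK]/sop on a default encode; count length without them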
+        if "chatglm3" in model_name:
+            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
         if len(tokenized_prompt) > max_length:
             half = int(max_length/2)
             prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
-- 
GitLab
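
Reviewer note on the get_pred hunk: a default encode with chatglm3's
tokenizer prepends the same [gMASK]/sop prefix, so the max_length check
would count two tokens that belong to the final chat input rather than to
the prompt text. Re-tokenizing with add_special_tokens=False measures the
prompt alone before the middle-truncation step. A hedged illustration,
using the same assumed tokenizer as in the note above:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm3-6b", trust_remote_code=True)
    prompt = "some long benchmark prompt"
    with_prefix = tokenizer(prompt, return_tensors="pt").input_ids[0]
    bare = tokenizer(prompt, return_tensors="pt",
                     add_special_tokens=False).input_ids[0]
    # the two encodings should differ only by the 2-token prefix
    assert len(with_prefix) == len(bare) + 2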