From 6747cbc9560a9e17b3eb2c55b56ad3dc76931160 Mon Sep 17 00:00:00 2001
From: zwcolin <ziruiw2000@gmail.com>
Date: Mon, 19 Aug 2024 14:59:11 -0400
Subject: [PATCH] update MiniCPM-V 2.6 evaluation code with the authors'
 official implementation

---
 src/generate_lib/minicpm.py | 35 ++++++++++++++++++++++++++++++-----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/src/generate_lib/minicpm.py b/src/generate_lib/minicpm.py
index f5b45df..bbe7e60 100644
--- a/src/generate_lib/minicpm.py
+++ b/src/generate_lib/minicpm.py
@@ -1,12 +1,24 @@
 # Adapted from https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5
+# Part of the V2.6 implementation is adapted directly from the authors
 # This has support for MiniCPM V2, V2.5, and V2.6
 
 from transformers import AutoModel, AutoTokenizer
 from tqdm import tqdm
 from PIL import Image
 import torch
+import random
+import numpy as np
+import math
 
-def generate_response(queries, model_path):
+def generate_response(queries, model_path, use_cot=False, random_upsize=False, seed=0):
+    if use_cot or random_upsize:
+        assert "MiniCPM-V-2_6" in model_path, "CoT and upsizing are provided by the paper's authors for MiniCPM-V 2.6 only"
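+    # seed every RNG source so the random upsizing below is reproducible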
+    if random_upsize:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
     # sdpa attn impl for v2.6, default for 2 and 2.5
     if "MiniCPM-V-2_6" in model_path:
         model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation='sdpa')
@@ -40,16 +52,29 @@ def generate_response(queries, model_path):
                 temperature=0.0,
                 top_p=1.0,
             )
-        # for 2.6
+        # for 2.6 (adapted directly from the authors' code)
         elif model_path.endswith("MiniCPM-V-2_6"):
-            msgs = [{'role': 'user', 'content': [image, query]}]
+            if random_upsize:
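+                # images with area below 1344x1344 are randomly upscaled, preserving aspect ratio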
+                img_width, img_height = image.width, image.height
+                if (img_width * img_height) < (1344 * 1344):
+                    ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
+                    max_img_width = int(img_width * ratio)
+                    new_img_width = random.randint(img_width, max_img_width)
+                    new_img_height = int(new_img_width / img_width * img_height)
+                    image = image.resize((new_img_width, new_img_height))
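+            # chain-of-thought instruction from the authors; prepended to the message only when use_cot is set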
+            system_cot_prompt = '''Based on the following image, please first give your understanding of the following question, then perform careful reasoning, and finally give the final answer.'''
+            msgs = [{'role': 'user', 'content': [image, query] if not use_cot else [system_cot_prompt, image, query]}]
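+            # decode deterministically (sampling=False) with beam search; generation settings follow the authors' evaluation code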
             res = model.chat(
                 image=None,
                 msgs=msgs,
                 tokenizer=tokenizer,
+                max_inp_length=8192,
                 sampling=False,
-                temperature=0.0,
-                top_p=1.0,
+                max_new_tokens=2048,
+                num_beams=3,
             )
         else:
             raise NotImplementedError(f"Model path {model_path} not supported") 
-- 
GitLab