From 6d225db59213d2e7383f1e1017953e384d19cfd5 Mon Sep 17 00:00:00 2001
From: Colin Wang <zw1300@princeton.edu>
Date: Mon, 5 Aug 2024 17:58:45 -0400
Subject: [PATCH] Add support for InternVL2-Pro and gemini-1.5-pro-exp-0801;
 fix generate_fn argument-order inconsistency

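Local model runners were inconsistent about whether generate_response
took (model_path, queries) or (queries, model_path). All of them now
take (queries, model_path), matching the call site in src/generate.py.
A minimal sketch of the updated convention, assuming the script is run
from src/ (so generate_lib is importable) and using placeholder values
for the model path and queries:

    # Hedged usage sketch: the model name and the query contents are
    # placeholders; only the argument order is the point.
    from generate_lib.utils import get_generate_fn

    model_path = "Ovis1.5-Llama3-8B"  # placeholder; any model handled in get_generate_fn
    queries = {"1": {"question": "What is plotted?", "figure_path": "fig.png"}}
    generate_fn = get_generate_fn(model_path)
    generate_fn(queries, model_path)  # queries first, then model_path

InternVL2-Pro is served through the hosted HTTP endpoint wrapped by the
new src/generate_lib/internvl2pro.py, and gemini-1.5-pro-exp-0801 reuses
the existing gemini client and generate helpers.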
---
 src/generate.py                  |  2 +-
 src/generate_lib/cambrian.py     |  2 +-
 src/generate_lib/chartgemma.py   |  2 +-
 src/generate_lib/deepseekvl.py   |  2 +-
 src/generate_lib/idefics2.py     |  2 +-
 src/generate_lib/internvl15.py   |  2 +-
 src/generate_lib/internvl2.py    |  2 +-
 src/generate_lib/internvl2pro.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 src/generate_lib/ixc2.py         |  2 +-
 src/generate_lib/llava16.py      |  2 +-
 src/generate_lib/mgm.py          |  2 +-
 src/generate_lib/minicpm.py      |  2 +-
 src/generate_lib/moai.py         |  2 +-
 src/generate_lib/ovis.py         |  2 +-
 src/generate_lib/paligemma.py    |  2 +-
 src/generate_lib/phi3.py         |  2 +-
 src/generate_lib/sphinx2.py      |  2 +-
 src/generate_lib/utils.py        | 12 ++++++++--
 src/generate_lib/vila15.py       |  2 +-
 19 files changed, 69 insertions(+), 19 deletions(-)
 create mode 100644 src/generate_lib/internvl2pro.py

diff --git a/src/generate.py b/src/generate.py
index e43fd11..a5927e3 100644
--- a/src/generate.py
+++ b/src/generate.py
@@ -51,7 +51,7 @@ if __name__ == '__main__':
         client, model = get_client_fn(args.model_path)(args.model_path, args.model_api)
         generate_response_remote_wrapper(generate_fn, queries, model, args.model_api, client)
     else:
-        generate_fn(args.model_path, queries)
+        generate_fn(queries, args.model_path)
 
     for k in queries:
         queries[k].pop("figure_path", None)
diff --git a/src/generate_lib/cambrian.py b/src/generate_lib/cambrian.py
index 416c081..0a5c0c1 100644
--- a/src/generate_lib/cambrian.py
+++ b/src/generate_lib/cambrian.py
@@ -19,7 +19,7 @@ from cambrian.mm_utils import tokenizer_image_token, process_images, get_model_n
 
 from PIL import Image
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     conv_mode = "chatml_direct"
     def process(image, question, tokenizer, image_processor, model_config):
         qs = question
diff --git a/src/generate_lib/chartgemma.py b/src/generate_lib/chartgemma.py
index 4337db5..c5a204d 100644
--- a/src/generate_lib/chartgemma.py
+++ b/src/generate_lib/chartgemma.py
@@ -6,7 +6,7 @@ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
 import torch
 from tqdm import tqdm
 
-def generate_response(queries, model_path=None):
+def generate_response(queries, model_path):
     # Load Model
     model = PaliGemmaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
     processor = AutoProcessor.from_pretrained(model_path)
diff --git a/src/generate_lib/deepseekvl.py b/src/generate_lib/deepseekvl.py
index 11b9e85..d709c72 100644
--- a/src/generate_lib/deepseekvl.py
+++ b/src/generate_lib/deepseekvl.py
@@ -8,7 +8,7 @@ from deepseek_vl.utils.io import load_pil_images
 import torch
 from tqdm import tqdm
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     # specify the path to the model
     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
     tokenizer = vl_chat_processor.tokenizer
diff --git a/src/generate_lib/idefics2.py b/src/generate_lib/idefics2.py
index 7b19fcd..d02ca64 100644
--- a/src/generate_lib/idefics2.py
+++ b/src/generate_lib/idefics2.py
@@ -4,7 +4,7 @@ from transformers.image_utils import load_image
 from transformers import AutoProcessor, AutoModelForVision2Seq
 from tqdm import tqdm
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     model = AutoModelForVision2Seq.from_pretrained(model_path).to('cuda')
     processor = AutoProcessor.from_pretrained(model_path)
     for k in tqdm(queries):
diff --git a/src/generate_lib/internvl15.py b/src/generate_lib/internvl15.py
index 2be19bc..d4469cc 100644
--- a/src/generate_lib/internvl15.py
+++ b/src/generate_lib/internvl15.py
@@ -87,7 +87,7 @@ def load_image(image_file, input_size=448, max_num=6):
     pixel_values = torch.stack(pixel_values)
     return pixel_values
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     model = AutoModel.from_pretrained(
         model_path,
         torch_dtype=torch.bfloat16,
diff --git a/src/generate_lib/internvl2.py b/src/generate_lib/internvl2.py
index cc1e7f4..d415da8 100644
--- a/src/generate_lib/internvl2.py
+++ b/src/generate_lib/internvl2.py
@@ -113,7 +113,7 @@ def split_model(model_name):
 
     return device_map
 
-def generate_response(queries, model_path=None):
+def generate_response(queries, model_path):
     device_map = split_model(model_path.split('/')[-1])
     print(device_map)
     model = AutoModel.from_pretrained(
diff --git a/src/generate_lib/internvl2pro.py b/src/generate_lib/internvl2pro.py
new file mode 100644
index 0000000..5c8aa65
--- /dev/null
+++ b/src/generate_lib/internvl2pro.py
@@ -0,0 +1,42 @@
+import requests
+
+
+def get_client_model(model_path, api_key):
+    assert api_key is not None, "API key is required for using InternVL2-Pro"
+    assert model_path is not None, "Model name is required for using InternVL2-Pro"
+    model = model_path
+    # The hosted InternVL2-Pro service is reached over plain HTTP, so no SDK client is created.
+    client = None
+    return client, model
+
+
+def generate_response(image_path, query, model, media_type="image/jpeg", api_key=None, client=None, random_baseline=False):
+    url = "http://101.132.98.120:11005/chat/"
+
+    file_paths = [
+        image_path
+    ]
+    question = query
+
+    # Send the figure(s) and question as a multipart POST; file handles are closed in the finally block.
+    files = [('files', open(file_path, 'rb')) for file_path in file_paths]
+    data = {
+        'question': question,
+        'api_key': api_key
+    }
+
+    try:
+        response = requests.post(url, files=files, data=data)
+        if response.status_code == 200:
+            result = response.json().get("response", "No response key found in the JSON.")
+            print("Response:", result)
+            return result
+        else:
+            print("Error:", response.status_code, response.text)
+            return "Error in generating response."
+    except requests.exceptions.RequestException as e:
+        print(f"Error: {e}")
+        return "Error in generating response."
+    finally:
+        for _, file_handle in files:
+            file_handle.close()
diff --git a/src/generate_lib/ixc2.py b/src/generate_lib/ixc2.py
index a96476b..b44e5b5 100644
--- a/src/generate_lib/ixc2.py
+++ b/src/generate_lib/ixc2.py
@@ -4,7 +4,7 @@ import torch
 from transformers import AutoModel, AutoTokenizer
 from tqdm import tqdm
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     # taken from: 
     torch.set_grad_enabled(False)
     if '4khd' in model_path:
diff --git a/src/generate_lib/llava16.py b/src/generate_lib/llava16.py
index 2865832..cc4a30e 100644
--- a/src/generate_lib/llava16.py
+++ b/src/generate_lib/llava16.py
@@ -6,7 +6,7 @@ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
 from tqdm import tqdm
 from PIL import Image
 
-def generate_responses(model_path, queries):
+def generate_responses(queries, model_path):
     # taken from: https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf
     processor = LlavaNextProcessor.from_pretrained(model_path)
     model = LlavaNextForConditionalGeneration.from_pretrained(model_path, 
diff --git a/src/generate_lib/mgm.py b/src/generate_lib/mgm.py
index 2cfeb5e..c5b81f2 100644
--- a/src/generate_lib/mgm.py
+++ b/src/generate_lib/mgm.py
@@ -64,7 +64,7 @@ def get_image_input_from_path(image, model, image_processor):
     return images, images_aux, 
 
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     disable_torch_init()
     model_path = os.path.expanduser(model_path)
     model_name = get_model_name_from_path(model_path)
diff --git a/src/generate_lib/minicpm.py b/src/generate_lib/minicpm.py
index 6b859e2..80c1a04 100644
--- a/src/generate_lib/minicpm.py
+++ b/src/generate_lib/minicpm.py
@@ -6,7 +6,7 @@ from tqdm import tqdm
 from PIL import Image
 import torch
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
     model = model.to(device='cuda', dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
diff --git a/src/generate_lib/moai.py b/src/generate_lib/moai.py
index 19f2ea3..cc8a3fd 100644
--- a/src/generate_lib/moai.py
+++ b/src/generate_lib/moai.py
@@ -16,7 +16,7 @@ import tqdm
 from PIL import Image
 import torch
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
     = prepare_moai(moai_path=model_path, bits=4, grad_ckpt=False, lora=False, dtype='fp16')
     for k in tqdm(queries):
diff --git a/src/generate_lib/ovis.py b/src/generate_lib/ovis.py
index 3b5ef82..5bdb500 100644
--- a/src/generate_lib/ovis.py
+++ b/src/generate_lib/ovis.py
@@ -6,7 +6,7 @@ from PIL import Image
 from transformers import AutoModelForCausalLM
 from tqdm import tqdm
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  torch_dtype=torch.bfloat16,
                                                  multimodal_max_length=8192,
diff --git a/src/generate_lib/paligemma.py b/src/generate_lib/paligemma.py
index 50ff0e8..b3ebee0 100644
--- a/src/generate_lib/paligemma.py
+++ b/src/generate_lib/paligemma.py
@@ -6,7 +6,7 @@ from PIL import Image
 import torch
 from tqdm import tqdm
 
-def generate_response(queries, model_path=None):
+def generate_response(queries, model_path):
     model_id = model_path
     device = "cuda:0"
     dtype = torch.bfloat16
diff --git a/src/generate_lib/phi3.py b/src/generate_lib/phi3.py
index 694e34e..0e959a2 100644
--- a/src/generate_lib/phi3.py
+++ b/src/generate_lib/phi3.py
@@ -5,7 +5,7 @@ from PIL import Image
 from transformers import AutoModelForCausalLM, AutoProcessor
 from tqdm import tqdm
 
-def generate_response(queries, model_path=None):
+def generate_response(queries, model_path):
     
 
     model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cuda", trust_remote_code=True, 
diff --git a/src/generate_lib/sphinx2.py b/src/generate_lib/sphinx2.py
index f69c25c..eeef793 100644
--- a/src/generate_lib/sphinx2.py
+++ b/src/generate_lib/sphinx2.py
@@ -5,7 +5,7 @@ from SPHINX import SPHINXModel
 from PIL import Image
 from tqdm import tqdm
 
-def generate_response(model_path, queries):
+def generate_response(queries, model_path):
     model = SPHINXModel.from_pretrained(pretrained_path=model_path, with_visual=True)
     for k in tqdm(queries):
         qas = [[queries[k]['question'], None]]
diff --git a/src/generate_lib/utils.py b/src/generate_lib/utils.py
index db4107a..94a0d33 100644
--- a/src/generate_lib/utils.py
+++ b/src/generate_lib/utils.py
@@ -33,7 +33,8 @@ def get_client_fn(model_path):
     # gemini
     elif model_path in ['gemini-1.5-pro-001', 
                         'gemini-1.0-pro-vision-001', 
-                        'gemini-1.5-flash-001']:
+                        'gemini-1.5-flash-001',
+                        'gemini-1.5-pro-exp-0801']:
         from .gemini import get_client_model
     # gpt
     elif model_path in ['gpt-4o-2024-05-13', 
@@ -49,6 +50,9 @@ def get_client_fn(model_path):
     elif model_path in ['qwen-vl-max', 
                         'qwen-vl-plus']:
         from .qwen import get_client_model
+    # internvl2pro
+    elif model_path in ['InternVL2-Pro']:
+        from .internvl2pro import get_client_model
     else:
         raise ValueError(f"Model {model_path} not supported")
     return get_client_model
@@ -73,7 +77,8 @@ def get_generate_fn(model_path):
     # gemini
     elif model_name in ['gemini-1.5-pro-001', 
                         'gemini-1.0-pro-vision-001', 
-                        'gemini-1.5-flash-001']:
+                        'gemini-1.5-flash-001',
+                        'gemini-1.5-pro-exp-0801']:
         from .gemini import generate_response
     # gpt
     elif model_name in ['gpt-4o-2024-05-13', 
@@ -135,6 +140,9 @@ def get_generate_fn(model_path):
     elif model_name in ['Ovis1.5-Llama3-8B',
                         'Ovis1.5-Gemma2-9B']:
         from .ovis import generate_response
+    # internvl2pro
+    elif model_name in ['InternVL2-Pro']:
+        from .internvl2pro import generate_response
     else:
         raise ValueError(f"Model {model_name} not supported")
     return generate_response
diff --git a/src/generate_lib/vila15.py b/src/generate_lib/vila15.py
index 0badfa0..223212f 100644
--- a/src/generate_lib/vila15.py
+++ b/src/generate_lib/vila15.py
@@ -34,7 +34,7 @@ def load_images(image_files):
         out.append(image)
     return out
 
-def generate_response(queries, model_path=None):
+def generate_response(queries, model_path):
     disable_torch_init()
     
     model_name = get_model_name_from_path(model_path)
-- 
GitLab