From 1bf46dd5b0acbc06817126a63c7fad913a351cd8 Mon Sep 17 00:00:00 2001
From: root <running.lsy@alibaba-inc.com>
Date: Wed, 31 Jul 2024 09:29:22 +0000
Subject: [PATCH] Add support for the Ovis model series
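
Add src/generate_lib/ovis.py with a generate_response() implementation
adapted from the AIDC-AI/Ovis1.5-Llama3-8B model card, and register the
Ovis1.5 models in get_generate_fn().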

---
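Notes (reviewer notes, not part of the commit message):

A minimal driver sketch, for illustration only. It assumes that
get_generate_fn() dispatches on the basename of the model path and
that queries is keyed by id with 'question' and 'figure_path' fields,
matching what ovis.py reads below; the example values are hypothetical.

    from generate_lib.utils import get_generate_fn

    model_path = 'AIDC-AI/Ovis1.5-Llama3-8B'
    queries = {'0': {'question': 'What is the peak value?',
                     'figure_path': 'figures/0.jpg'}}

    generate_fn = get_generate_fn(model_path)
    generate_fn(model_path, queries)  # fills queries['0']['response'] in place
    print(queries['0']['response'])
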
 src/generate_lib/ovis.py  | 50 ++++++++++++++++++++++++++++++++++++++++++++++
 src/generate_lib/utils.py |  4 ++++
 2 files changed, 54 insertions(+)
 create mode 100644 src/generate_lib/ovis.py

diff --git a/src/generate_lib/ovis.py b/src/generate_lib/ovis.py
new file mode 100644
index 0000000..3b5ef82
--- /dev/null
+++ b/src/generate_lib/ovis.py
@@ -0,0 +1,50 @@
+# Adapted from https://huggingface.co/AIDC-AI/Ovis1.5-Llama3-8B
+# Implements generate_response() for the Ovis model series
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM
+from tqdm import tqdm
+
+def generate_response(model_path, queries):
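+    # Load the checkpoint in bfloat16; trust_remote_code pulls in Ovis's custom multimodal classes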
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 torch_dtype=torch.bfloat16,
+                                                 multimodal_max_length=8192,
+                                                 trust_remote_code=True).cuda()
+    text_tokenizer = model.get_text_tokenizer()
+    visual_tokenizer = model.get_visual_tokenizer()
+    conversation_formatter = model.get_conversation_formatter()
+
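+    # queries maps an id to a dict carrying the question text and the figure path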
+    for k in tqdm(queries):
+        query = queries[k]['question']
+        figure_path = queries[k]['figure_path']
+        image = Image.open(figure_path).convert('RGB')
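+        # Prepend the image placeholder, then apply the model's chat template to get prompt ids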
+        query = f'<image>\n{query}'
+        prompt, input_ids = conversation_formatter.format_query(query)
+        input_ids = torch.unsqueeze(input_ids, dim=0).to(device=model.device)
+        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(device=model.device)
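+        # Preprocess the figure for the visual tokenizer (a list with one tensor per image)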
+        pixel_values = [visual_tokenizer.preprocess_image(image).to(
+            dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
+
+        with torch.inference_mode():
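+            # Greedy decoding: do_sample=False with the sampling knobs explicitly unset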
+            gen_kwargs = dict(
+                max_new_tokens=1024,
+                do_sample=False,
+                top_p=None,
+                top_k=None,
+                temperature=None,
+                repetition_penalty=None,
+                eos_token_id=model.generation_config.eos_token_id,
+                pad_token_id=text_tokenizer.pad_token_id,
+                use_cache=True
+            )
+            output_ids = model.generate(input_ids, pixel_values=pixel_values,
+                                        attention_mask=attention_mask, **gen_kwargs)[0]
+            response = text_tokenizer.decode(output_ids, skip_special_tokens=True)
+
+        queries[k]['response'] = response
diff --git a/src/generate_lib/utils.py b/src/generate_lib/utils.py
index baaa398..db4107a 100644
--- a/src/generate_lib/utils.py
+++ b/src/generate_lib/utils.py
@@ -131,6 +131,10 @@ def get_generate_fn(model_path):
     # vila
     elif model_name in ['VILA1.5-40b']:
         from .vila15 import generate_response
+    # ovis
+    elif model_name in ['Ovis1.5-Llama3-8B',
+                        'Ovis1.5-Gemma2-9B']:
+        from .ovis import generate_response
     else:
         raise ValueError(f"Model {model_name} not supported")
     return generate_response
-- 
GitLab