From 091caa890c99421e49fdff7485bf08bdfc946a85 Mon Sep 17 00:00:00 2001
From: Colin Wang <zw1300@princeton.edu>
Date: Sun, 18 Aug 2024 18:18:17 -0400
Subject: [PATCH] Add support for domain-specific models: 1. ChartAssistant
 (`chartast.py`) 2. ChartInstruct (`chartinstruct.py`) 3. ChartLlama
 (`chartllama.py`) 4. CogAgent (`cogagent.py`) 5. DocOwl1.5 (`docowl15.py`) 6.
 TextMonkey (`textmonkey.py`) 7. TinyChart (`tinychart.py`) 8. UniChart
 (`unichart.py`) 9. UReader (`ureader.py`)
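
Each module implements generate_response(queries, model_path): `queries` maps a sample ID to a
dict carrying 'figure_path' and 'question', and the function writes the model output back into
queries[k]['response']. ChartLlama and CogAgent take a composite model_path of the form
"<base-or-tokenizer>::<checkpoint>"; the other models take a single checkpoint path or Hugging
Face model ID. Several wrappers additionally require VLM_CODEBASE_DIR to point at a local
checkout of the corresponding upstream repository.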

---
 src/generate_lib/chartast.py      |  98 ++++++++++++++++++
 src/generate_lib/chartinstruct.py |  46 +++++++++
 src/generate_lib/chartllama.py    | 160 ++++++++++++++++++++++++++++++
 src/generate_lib/cogagent.py      |  49 +++++++++
 src/generate_lib/docowl15.py      |  25 +++++
 src/generate_lib/textmonkey.py    |  82 +++++++++++++++
 src/generate_lib/tinychart.py     |  44 ++++++++
 src/generate_lib/unichart.py      |  38 +++++++
 src/generate_lib/ureader.py       |  55 ++++++++++
 9 files changed, 597 insertions(+)
 create mode 100644 src/generate_lib/chartast.py
 create mode 100644 src/generate_lib/chartinstruct.py
 create mode 100644 src/generate_lib/chartllama.py
 create mode 100644 src/generate_lib/cogagent.py
 create mode 100644 src/generate_lib/docowl15.py
 create mode 100644 src/generate_lib/textmonkey.py
 create mode 100644 src/generate_lib/tinychart.py
 create mode 100644 src/generate_lib/unichart.py
 create mode 100644 src/generate_lib/ureader.py

diff --git a/src/generate_lib/chartast.py b/src/generate_lib/chartast.py
new file mode 100644
index 0000000..46709b2
--- /dev/null
+++ b/src/generate_lib/chartast.py
@@ -0,0 +1,98 @@
+# Adapted from https://github.com/OpenGVLab/ChartAst/blob/main/accessory/single_turn_eval.py
+# This has support for the ChartAssistant model
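+#
+# Expected input format (illustrative; paths are placeholders):
+#   queries = {'0': {'figure_path': 'images/0.png', 'question': 'What is the peak value?'}}
+#   generate_response(queries, '/path/to/ChartAssistant/checkpoint')  # fills queries['0']['response']
+#
+# VLM_CODEBASE_DIR must point to the directory that contains a checkout of the ChartAst repository.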
+
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/ChartAst/accessory')
+
+os.environ['MP'] = '1'
+os.environ['WORLD_SIZE'] = '1'
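+# single-process run with model-parallel size 1, so the Accessory distributed init below uses one GPU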
+
+import torch
+from tqdm import tqdm
+import torch.distributed as dist
+
+
+sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
+from fairscale.nn.model_parallel import initialize as fs_init
+from model.meta import MetaModel
+from util.tensor_parallel import load_tensor_parallel_model_list
+from util.misc import init_distributed_mode
+from PIL import Image
+
+import torchvision.transforms as transforms
+
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+class PadToSquare:
+    def __init__(self, background_color):
+        """
+        Pad an image to a square (borrowed from LLaVA).
+        :param background_color: rgb values for padded pixels, normalized to [0, 1]
+        """
+        self.bg_color = tuple(int(x * 255) for x in background_color)
+
+    def __call__(self, img: Image.Image):
+        width, height = img.size
+        if width == height:
+            return img
+        elif width > height:
+            result = Image.new(img.mode, (width, width), self.bg_color)
+            result.paste(img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(img.mode, (height, height), self.bg_color)
+            result.paste(img, ((height - width) // 2, 0))
+            return result
+
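+# Note: the padding color used below equals the CLIP normalization mean, so padded pixels become ~0 after Normalize.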
+def T_padded_resize(size=448):
+    t = transforms.Compose([
+        PadToSquare(background_color=(0.48145466, 0.4578275, 0.40821073)),
+        transforms.Resize(
+            size, interpolation=transforms.InterpolationMode.BICUBIC
+        ),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+    return t
+
+def generate_response(queries, model_path):
+    init_distributed_mode()
+    fs_init.initialize_model_parallel(dist.get_world_size())
+    model = MetaModel('llama_ens5', model_path + '/params.json', model_path + '/tokenizer.model', with_visual=True)
+    print(f"load pretrained from {model_path}")
+    load_tensor_parallel_model_list(model, model_path)
+    model.bfloat16().cuda()
+    max_gen_len = 512
+    gen_t = 0.9
+    top_p = 0.5
+
+    for k in tqdm(queries):
+        question = queries[k]['question']
+        img_path = queries[k]['figure_path']
+
+        prompt = ("Below is an instruction that describes a task. "
+                  "Write a response that appropriately completes the request.\n\n"
+                  f"### Instruction:\nPlease answer my question based on the chart: {question}\n\n### Response:")
+
+        image = Image.open(img_path).convert('RGB')
+        transform_val = T_padded_resize(448)
+        image = transform_val(image).unsqueeze(0)
+        image = image.cuda()
+
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            response = model.generate([prompt], image, max_gen_len=max_gen_len, temperature=gen_t, top_p=top_p)
+        response = response[0].split('###')[0]
+        print(response)
+        queries[k]['response'] = response
diff --git a/src/generate_lib/chartinstruct.py b/src/generate_lib/chartinstruct.py
new file mode 100644
index 0000000..546aa4d
--- /dev/null
+++ b/src/generate_lib/chartinstruct.py
@@ -0,0 +1,46 @@
+# Adapted from https://huggingface.co/ahmed-masry/ChartInstruct-LLama2, https://huggingface.co/ahmed-masry/ChartInstruct-FlanT5-XL
+# This has support for two ChartInstruct models, LLama2 and FlanT5
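+# The branch below keys off the checkpoint name: model_path must contain "LLama2" or "FlanT5"
+# (e.g. the two checkpoints linked above).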
+
+from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration, AutoModelForSeq2SeqLM
+import torch
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    if "LLama2" in model_path:
+        print("Using LLama2 model")
+        model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
+    elif "FlanT5" in model_path:
+        print("Using FlanT5 model")
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
+    else:
+        raise ValueError(f"Model {model_path} not supported")
+    processor = AutoProcessor.from_pretrained(model_path)
+
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        input_prompt = queries[k]['question']
+        input_prompt = f"<image>\n Question: {input_prompt} Answer: "
+
+        image = Image.open(image_path).convert('RGB')
+        inputs = processor(text=input_prompt, images=image, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # cast pixel_values to fp16 to match the model weights
+        inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)
+        if "LLama2" in model_path:
+            # remember the prompt length so it can be stripped from the decoder output below
+            prompt_length = inputs['input_ids'].shape[1]
+
+        # Generate
+        generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
+        output_text = processor.batch_decode(generate_ids[:, prompt_length:] \
+            if 'LLama2' in model_path else generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        print(output_text)
+        queries[k]['response'] = output_text
diff --git a/src/generate_lib/chartllama.py b/src/generate_lib/chartllama.py
new file mode 100644
index 0000000..8d5eb2f
--- /dev/null
+++ b/src/generate_lib/chartllama.py
@@ -0,0 +1,160 @@
+# Adapted from https://github.com/tingxueronghua/ChartLlama-code/blob/main/model_vqa_lora.py
+# This has support for the ChartLlama model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/ChartLlama-code')
+### HEADER END ###
+
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model import *
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+from torch.utils.data import Dataset, DataLoader
+
+from PIL import Image
+import math
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
+    kwargs = {"device_map": device_map}
+
+    if load_8bit:
+        kwargs['load_in_8bit'] = True
+    elif load_4bit:
+        kwargs['load_in_4bit'] = True
+        kwargs['quantization_config'] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4'
+        )
+    else:
+        kwargs['torch_dtype'] = torch.float16
+
+    # Load LLaVA model
+    if model_base is None:
+        raise ValueError('ChartLlama is loaded as LoRA weights on top of a base LLaVA model, so `model_base` must be provided. Detailed instructions: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
+    if model_base is not None:
+        lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+        print('Loading LLaVA from base model...')
+        model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+        token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
+        if model.lm_head.weight.shape[0] != token_num:
+            model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+            model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+
+        print('Loading additional LLaVA weights...')
+        if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+        else:
+            # this is probably from HF Hub
+            from huggingface_hub import hf_hub_download
+            def load_from_hf(repo_id, filename, subfolder=None):
+                cache_file = hf_hub_download(
+                    repo_id=repo_id,
+                    filename=filename,
+                    subfolder=subfolder)
+                return torch.load(cache_file, map_location='cpu')
+            non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
+        non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+        if any(k.startswith('model.model.') for k in non_lora_trainables):
+            non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+        model.load_state_dict(non_lora_trainables, strict=False)
+
+        from peft import PeftModel
+        print('Loading LoRA weights...')
+        model = PeftModel.from_pretrained(model, model_path)
+        print('Merging LoRA weights...')
+        model = model.merge_and_unload()
+        print('Model is loaded...')
+
+    image_processor = None
+
+    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+    if mm_use_im_patch_token:
+        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+    if mm_use_im_start_end:
+        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+    model.resize_token_embeddings(len(tokenizer))
+
+    vision_tower = model.get_vision_tower()
+    if not vision_tower.is_loaded:
+        vision_tower.load_model()
+    vision_tower.to(device=device, dtype=torch.float16)
+    image_processor = vision_tower.image_processor
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, image_processor, context_len
+
+
+def generate_response(queries, model_path):
+    disable_torch_init()
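+    # model_path is expected as "<base LLaVA model>::<ChartLlama LoRA checkpoint>" (paths are user-supplied);
+    # the LoRA weights are merged into the base model inside load_pretrained_model()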
+    base_model_path, model_path = model_path.split('::')
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, base_model_path, None)
+    conv_mode = "vicuna_v1"
+
+    def process(image, question, tokenizer, image_processor, model_config):
+        qs = question.replace(DEFAULT_IMAGE_TOKEN, '').strip()
+        if model.config.mm_use_im_start_end:
+            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+        else:
+            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+        conv = conv_templates[conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        image_tensor = process_images([image], image_processor, model_config)[0]
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
+    
+        return input_ids, image_tensor
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        image = Image.open(image_path).convert('RGB')
+        question = queries[k]['question']
+
+        input_ids, image_tensor = process(image, question, tokenizer, image_processor, model.config)
+        stop_str = conv_templates[conv_mode].sep if conv_templates[conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[conv_mode].sep2
+        input_ids = input_ids.to(device='cuda', non_blocking=True).unsqueeze(0) # added the unsqueeze(0) to make it batch size 1
+        image_tensor = image_tensor.unsqueeze(0) # added the unsqueeze(0) to make it batch size 1
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
+                do_sample=False,
+                max_new_tokens=1636,
+                use_cache=True
+            )
+        input_token_len = input_ids.shape[1]
+        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+        if n_diff_input_output > 0:
+            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
+        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+        outputs = outputs.strip()
+        if outputs.endswith(stop_str):
+            outputs = outputs[:-len(stop_str)]
+        outputs = outputs.strip()
+        queries[k]['response'] = outputs
diff --git a/src/generate_lib/cogagent.py b/src/generate_lib/cogagent.py
new file mode 100644
index 0000000..8279857
--- /dev/null
+++ b/src/generate_lib/cogagent.py
@@ -0,0 +1,49 @@
+# Adapted from https://huggingface.co/THUDM/cogagent-vqa-hf
+# This has support for the CogAgent model
+
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, LlamaTokenizer
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+    torch_type = torch.bfloat16
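+    # model_path is expected as "<tokenizer path>::<model path>"; the CogAgent model card pairs a
+    # Vicuna-7B-v1.5 tokenizer with THUDM/cogagent-vqa-hf (illustrative, not enforced here)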
+    tokenizer_path, model_path = model_path.split('::')
+    tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
+    model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            load_in_4bit=False,
+            trust_remote_code=True
+        ).to('cuda').eval()
+        
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        image = Image.open(image_path).convert('RGB')
+        query = f"Human:{queries[k]['question']}"
+        history = []
+    
+        input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+        inputs = {
+            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]],
+        }
+        if 'cross_images' in input_by_model and input_by_model['cross_images']:
+            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+        # generation settings (temperature has no effect here since do_sample=False)
+        gen_kwargs = {"max_length": 2048,
+                      "temperature": 0.9,
+                      "do_sample": False}
+        with torch.no_grad():
+            outputs = model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            response = tokenizer.decode(outputs[0])
+            response = response.split("</s>")[0]
+            print("\nCog:", response)
+        print('model_answer:', response)
+        queries[k]['response'] = response
diff --git a/src/generate_lib/docowl15.py b/src/generate_lib/docowl15.py
new file mode 100644
index 0000000..01d39e9
--- /dev/null
+++ b/src/generate_lib/docowl15.py
@@ -0,0 +1,25 @@
+# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/DocOwl1.5/docowl_infer.py
+# This has support for the DocOwl1.5 model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/mPLUG-DocOwl/DocOwl1.5')
+### HEADER END ###
+
+from docowl_infer import DocOwlInfer
+from tqdm import tqdm
+import os
+
+def generate_response(queries, model_path):
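+    # anchors='grid_9' and add_global_img=True mirror the upstream docowl_infer defaults
+    # (shape-adaptive cropping into up to 9 sub-images plus a global view)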
+    docowl = DocOwlInfer(ckpt_path=model_path, anchors='grid_9', add_global_img=True)
+    print('load model from ', model_path)
+    # infer the test samples one by one
+    for k in tqdm(queries):
+        image = queries[k]['figure_path']
+        question = queries[k]['question']
+        model_answer = docowl.inference(image, question)
+        print('model_answer:', model_answer)
+        queries[k]['response'] = model_answer
diff --git a/src/generate_lib/textmonkey.py b/src/generate_lib/textmonkey.py
new file mode 100644
index 0000000..9f449c6
--- /dev/null
+++ b/src/generate_lib/textmonkey.py
@@ -0,0 +1,82 @@
+# Adapted from https://github.com/Yuliang-Liu/Monkey/blob/main/demo_textmonkey.py
+# This has support for the TextMonkey model
+
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/Monkey')
+
+import re
+import gradio as gr
+from PIL import Image, ImageDraw, ImageFont
+from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel
+from monkey_model.tokenization_qwen import QWenTokenizer
+from monkey_model.configuration_monkey import MonkeyConfig
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    device_map = "cuda"
+    # Create model
+    config = MonkeyConfig.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+        )
+    model = TextMonkeyLMHeadModel.from_pretrained(model_path,
+        config=config,
+        device_map=device_map, trust_remote_code=True).eval()
+    tokenizer = QWenTokenizer.from_pretrained(model_path,
+                                                trust_remote_code=True)
+    tokenizer.padding_side = 'left'
+    tokenizer.pad_token_id = tokenizer.eod_id
+    tokenizer.IMG_TOKEN_SPAN = config.visual["n_queries"]
+
+    for k in tqdm(queries):
+        input_image = queries[k]['figure_path']
+        input_str = queries[k]['question']
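+        # TextMonkey takes the image path inline, wrapped in <img>...</img> tags inside the prompt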
+        input_str = f"<img>{input_image}</img> {input_str}"
+        input_ids = tokenizer(input_str, return_tensors='pt', padding='longest')
+
+        attention_mask = input_ids.attention_mask
+        input_ids = input_ids.input_ids
+        
+        pred = model.generate(
+            input_ids=input_ids.cuda(),
+            attention_mask=attention_mask.cuda(),
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=2048,
+            min_new_tokens=1,
+            length_penalty=1,
+            num_return_sequences=1,
+            output_hidden_states=True,
+            use_cache=True,
+            pad_token_id=tokenizer.eod_id,
+            eos_token_id=tokenizer.eod_id,
+        )
+        response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=False).strip()
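+        # The block below reproduces the demo's grounding visualization: it draws the predicted
+        # <box> centers (and <ref> labels, if any) on a resized copy of the image. The drawn image
+        # is neither saved nor returned; only the re-decoded text at the end is used as the answer.
+        # It requires the 'NimbusRoman-Regular.otf' font to be locatable by PIL.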
+        image = Image.open(input_image).convert("RGB").resize((1000,1000))
+        font = ImageFont.truetype('NimbusRoman-Regular.otf', 22)
+        bboxes = re.findall(r'<box>(.*?)</box>', response, re.DOTALL)
+        refs = re.findall(r'<ref>(.*?)</ref>', response, re.DOTALL)
+        if len(refs)!=0:
+            num = min(len(bboxes), len(refs))
+        else:
+            num = len(bboxes)
+        for box_id in range(num):
+            bbox = bboxes[box_id]
+            matches = re.findall( r"\((\d+),(\d+)\)", bbox)
+            draw = ImageDraw.Draw(image)
+            point_x = (int(matches[0][0])+int(matches[1][0]))/2
+            point_y = (int(matches[0][1])+int(matches[1][1]))/2
+            point_size = 8
+            point_bbox = (point_x - point_size, point_y - point_size, point_x + point_size, point_y + point_size)
+            draw.ellipse(point_bbox, fill=(255, 0, 0))
+            if len(refs)!=0:
+                text = refs[box_id]
+                text_width, text_height = font.getsize(text)
+                draw.text((point_x-text_width//2, point_y+8), text, font=font, fill=(255, 0, 0))
+        response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+        output_str = response
+        print(f"Answer: {output_str}")
+        queries[k]['response'] = output_str
diff --git a/src/generate_lib/tinychart.py b/src/generate_lib/tinychart.py
new file mode 100644
index 0000000..10677aa
--- /dev/null
+++ b/src/generate_lib/tinychart.py
@@ -0,0 +1,44 @@
+# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/TinyChart/inference.ipynb
+# This has support for the TinyChart model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/mPLUG-DocOwl/TinyChart')
+### HEADER END ###
+
+from tqdm import tqdm
+import torch
+from PIL import Image
+from tinychart.model.builder import load_pretrained_model
+from tinychart.mm_utils import get_model_name_from_path
+from tinychart.eval.run_tiny_chart import inference_model
+from tinychart.eval.eval_metric import parse_model_output, evaluate_cmds
+
+def generate_response(queries, model_path):
+    tokenizer, model, image_processor, context_len = load_pretrained_model(
+        model_path, 
+        model_base=None,
+        model_name=get_model_name_from_path(model_path),
+        device="cuda" # device="cpu" if running on cpu
+    )
+    for k in tqdm(queries):
+        img_path = queries[k]['figure_path']
+        text = queries[k]['question']
+        response = inference_model([img_path], text, model, tokenizer, image_processor, context_len, conv_mode="phi", max_new_tokens=1024)
+        # print(response)
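+        # TinyChart may answer with an executable program (Program-of-Thoughts); try to parse and
+        # run it, and fall back to the raw text response if execution fails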
+        try:
+            response = evaluate_cmds(parse_model_output(response))
+            print('Command successfully executed')
+            print(response)
+        except Exception as e:
+            # keep the raw text response if the output is not an executable program
+            # (a NameError about 'Answer' means there was no program to run)
+            if "name 'Answer' is not defined" not in str(e):
+                print('Error:', e)
+        response = str(response)
+        queries[k]['response'] = response
diff --git a/src/generate_lib/unichart.py b/src/generate_lib/unichart.py
new file mode 100644
index 0000000..ed1d606
--- /dev/null
+++ b/src/generate_lib/unichart.py
@@ -0,0 +1,38 @@
+# Adapted from https://github.com/vis-nlp/UniChart/blob/main/README.md
+# This has support for the UniChart model
+
+from transformers import DonutProcessor, VisionEncoderDecoderModel
+from PIL import Image
+import torch
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
+    model = VisionEncoderDecoderModel.from_pretrained(model_path)
+    processor = DonutProcessor.from_pretrained(model_path)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        input_prompt = queries[k]['question']
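+        # UniChart uses Donut-style task prompts: '<chartqa>' selects the QA task and the decoder
+        # generates the answer after the '<s_answer>' token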
+        input_prompt = f"<chartqa> {input_prompt} <s_answer>"
+        image = Image.open(image_path).convert("RGB")
+        decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+        pixel_values = processor(image, return_tensors="pt").pixel_values
+
+        outputs = model.generate(
+            pixel_values.to(device),
+            decoder_input_ids=decoder_input_ids.to(device),
+            max_length=model.decoder.config.max_position_embeddings,
+            early_stopping=True,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,
+            num_beams=4,
+            bad_words_ids=[[processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+        )
+        sequence = processor.batch_decode(outputs.sequences)[0]
+        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+        sequence = sequence.split("<s_answer>")[1].strip()
+        queries[k]['response'] = sequence
diff --git a/src/generate_lib/ureader.py b/src/generate_lib/ureader.py
new file mode 100644
index 0000000..879fd15
--- /dev/null
+++ b/src/generate_lib/ureader.py
@@ -0,0 +1,55 @@
+# Adapted from https://github.com/X-PLUG/mPLUG-DocOwl/blob/main/UReader/pipeline/interface.py
+# This has support for the UReader model
+
+### HEADER START ###
+import os
+vlm_codebase = os.environ['VLM_CODEBASE_DIR']
+
+import sys
+sys.path.append(vlm_codebase + '/mPLUG-DocOwl/UReader')
+
+UREADER_DIR = os.path.join(vlm_codebase, 'mPLUG-DocOwl/UReader/')
+### HEADER END ###
+
+import torch
+from sconf import Config
+from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
+from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
+from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
+from pipeline.data_utils.processors.builder import build_processors
+from pipeline.data_utils.processors import *
+from pipeline.utils import add_config_args, set_args
+import argparse
+
+from PIL import Image
+from tqdm import tqdm
+
+def generate_response(queries, model_path):
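+    # The UReader release config supplies the 'sft' image preprocessing pipeline used below;
+    # add_config_args/set_args make it visible to the mPLUG-Owl pipeline modules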
+    config = Config("{}configs/sft/release.yaml".format(UREADER_DIR))
+    args = argparse.ArgumentParser().parse_args([])
+    add_config_args(config, args)
+    set_args(args)
+    model = MplugOwlForConditionalGeneration.from_pretrained(
+        model_path,
+    )
+    model.eval()
+    model.cuda()
+    model.half()
+    image_processor = build_processors(config['valid_processors'])['sft']
+    tokenizer = MplugOwlTokenizer.from_pretrained(model_path)
+    processor = MplugOwlProcessor(image_processor, tokenizer)
+
+    for k in tqdm(queries):
+        image_path = queries[k]['figure_path']
+        images = [Image.open(image_path).convert('RGB')]
+        question = f"Human: <image>\nHuman: {queries[k]['question']}\nAI: "
+        inputs = processor(text=question, images=images, return_tensors='pt')
+        inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            res = model.generate(**inputs)
+        sentence = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
+        print('model_answer:', sentence)
+        queries[k]['response'] = sentence
-- 
GitLab