# Adapted from https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5
# This has support for MiniCPM V2 and V2.5, and V2.6

from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
from PIL import Image
import torch

def generate_response(queries, model_path):
    # sdpa attn impl for v2.6, default for 2 and 2.5
    if "MiniCPM-V-2_6" in model_path:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation='sdpa')
    else:
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model = model.eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    for k in tqdm(queries):
        query = queries[k]['question']
        image = Image.open(queries[k]["figure_path"]).convert('RGB')
        if model_path.endswith("MiniCPM-V-2"):
            msgs = [{'role': 'user', 'content': query}]
            res, context, _ = model.chat(
                image=image,
                msgs=msgs,
                context=None,
                tokenizer=tokenizer,
                sampling=False,
                temperature=0.0,
                top_p=1.0,
            )
        # for 2.5
        elif model_path.endswith("MiniCPM-Llama3-V-2_5"):
            msgs = [{'role': 'user', 'content': query}]
            res = model.chat(
                image=image,
                msgs=msgs,
                tokenizer=tokenizer,
                sampling=False,
                temperature=0.0,
                top_p=1.0,
            )
        # for 2.6
        elif model_path.endswith("MiniCPM-V-2_6"):
            msgs = [{'role': 'user', 'content': [image, query]}]
            res = model.chat(
                image=None,
                msgs=msgs,
                tokenizer=tokenizer,
                sampling=False,
                temperature=0.0,
                top_p=1.0,
            )
        else:
            raise NotImplementedError(f"Model path {model_path} not supported") 
        queries[k]['response'] = res