# Adapted from https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5
# Supports MiniCPM-V 2, MiniCPM-Llama3-V 2.5, and MiniCPM-V 2.6.

from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
from PIL import Image
import torch

def _minicpm_variant(model_path):
    """Map a checkpoint path to the MiniCPM-V variant tag "2", "2.5" or "2.6".

    The variant is recognized from the last component of ``model_path``;
    trailing path separators are ignored so ``ckpts/MiniCPM-V-2_6/`` is
    accepted as well as ``ckpts/MiniCPM-V-2_6``.

    Raises:
        NotImplementedError: if the path does not end with one of the
            supported checkpoint names.
    """
    name = model_path.rstrip("/\\")
    if name.endswith("MiniCPM-V-2"):
        return "2"
    if name.endswith("MiniCPM-Llama3-V-2_5"):
        return "2.5"
    if name.endswith("MiniCPM-V-2_6"):
        return "2.6"
    raise NotImplementedError(f"Model path {model_path} not supported")


def generate_response(queries: dict, model_path: str) -> None:
    """Answer every query with a MiniCPM-V model and store the result in place.

    Args:
        queries: Mapping whose values are dicts containing a ``'question'``
            string and a ``'figure_path'`` pointing at an image file. For
            each entry the model output is written to ``['response']``.
        model_path: Checkpoint name or path passed to ``from_pretrained``.
            It must end with ``MiniCPM-V-2``, ``MiniCPM-Llama3-V-2_5`` or
            ``MiniCPM-V-2_6`` (a trailing path separator is tolerated).

    Raises:
        NotImplementedError: if ``model_path`` is not a supported variant.
            This is checked before any weights are loaded, so an unsupported
            path fails fast even when ``queries`` is empty.
    """
    # Resolve the variant once: it drives both the loading options and the
    # chat-call shape, and validating up front avoids loading an unusable model.
    variant = _minicpm_variant(model_path)

    load_kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
    if variant == "2.6":
        # sdpa attention implementation for v2.6; default for 2 and 2.5.
        load_kwargs["attn_implementation"] = "sdpa"
    model = AutoModel.from_pretrained(model_path, **load_kwargs)
    model = model.eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Decoding options shared by every variant (sampling disabled).
    decode_kwargs = {
        "tokenizer": tokenizer,
        "sampling": False,
        "temperature": 0.0,
        "top_p": 1.0,
    }

    for k in tqdm(queries):
        entry = queries[k]
        query = entry['question']
        # Close the file handle promptly; convert() returns an independent
        # in-memory copy, so the image remains usable after the file closes.
        with Image.open(entry["figure_path"]) as raw_image:
            image = raw_image.convert('RGB')

        if variant == "2":
            # V2's chat() takes a context argument and returns a 3-tuple;
            # only the response text is kept.
            res, _context, _ = model.chat(
                image=image,
                msgs=[{'role': 'user', 'content': query}],
                context=None,
                **decode_kwargs,
            )
        elif variant == "2.5":
            res = model.chat(
                image=image,
                msgs=[{'role': 'user', 'content': query}],
                **decode_kwargs,
            )
        else:
            # V2.6 takes the image inside the message content rather than
            # through the image argument.
            res = model.chat(
                image=None,
                msgs=[{'role': 'user', 'content': [image, query]}],
                **decode_kwargs,
            )

        entry['response'] = res