import torch
from PIL import Image
from tqdm import tqdm
from transformers import MllamaForConditionalGeneration, AutoProcessor

def generate_response(queries, model_path):
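    """Answer each image/question pair in `queries` in place.

    Each entry of `queries` is expected to hold a 'question' string and a
    'figure_path' pointing at an image file; the model's answer is written
    back under a 'response' key.
    """
    # Load the vision-language model in bfloat16, sharded across available
    # devices; the processor prepares both the image and the text prompt.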
    model = MllamaForConditionalGeneration.from_pretrained(model_path,
                                                           torch_dtype=torch.bfloat16,
                                                           device_map="auto")
    processor = AutoProcessor.from_pretrained(model_path)

    for k in tqdm(queries):
        query = queries[k]["question"]
        image = Image.open(queries[k]["figure_path"]).convert("RGB")
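        # Single-turn chat message: an image placeholder followed by the
        # question text, rendered into a prompt by the chat template below.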
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": query}
            ]}
        ]
        input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(
            image,
            input_text,
            add_special_tokens=False,  # the chat template already added them
            return_tensors="pt"
        ).to(model.device)

        output = model.generate(**inputs, max_new_tokens=1024)
        # Decode the full sequence, then keep only the text after the
        # assistant header and strip the end-of-turn token.
        response = processor.decode(output[0])
        response = response.split("<|start_header_id|>assistant<|end_header_id|>")[1].replace("<|eot_id|>", "").strip()
        queries[k]["response"] = response
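

# A minimal usage sketch. The model id and file path below are assumptions
# for illustration: any Mllama checkpoint (e.g. a local path to
# Llama-3.2-11B-Vision-Instruct) and any existing image file will do.
if __name__ == "__main__":
    queries = {
        "q1": {
            "question": "What does the figure show?",
            "figure_path": "figures/example.png",  # hypothetical path
        }
    }
    generate_response(queries, "meta-llama/Llama-3.2-11B-Vision-Instruct")
    print(queries["q1"]["response"])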