import os, json
from constants import DESCRIPTIVE_RESP_INST, DESCRIPTIVE_GRADING_PREFIX, \
            DESCRIPTIVE_GRADING_QMAP, DESCRIPTIVE_GRADING_ICL

def get_rubric(qid):
    """Return the in-context grading example for a question template id.

    The template id is matched against the categories listed below and the
    corresponding entry of ``DESCRIPTIVE_GRADING_ICL`` is returned.

    Args:
        qid: question template id.

    Returns:
        The ``DESCRIPTIVE_GRADING_ICL`` entry for the category of ``qid``.

    Raises:
        AssertionError: if ``qid`` is not covered by any category.
    """
    # (template ids, key into DESCRIPTIVE_GRADING_ICL)
    categories = (
        ((1,), 'title'),
        ((2, 3, 4, 5, 6, 7), 'ocr'),
        ((8, 9, 10, 12, 14, 15, 17, 19), 'quant'),
        ((11,), 'bool'),
        ((13,), 'enum'),
        ((16,), 'trend'),
        ((18,), 'layout'),
    )
    instruction = None
    for template_ids, icl_key in categories:
        if qid in template_ids:
            instruction = DESCRIPTIVE_GRADING_ICL[icl_key]
    assert instruction is not None, f"Instruction for qid {qid} is not found."
    return instruction

def get_descriptive_result_gpt(client, prompt, length, max_retries=10):
    """Ask GPT-4o to grade a batch and return the parsed JSON verdict.

    The prompt is sent as a single user message with JSON output requested.
    The reply is parsed with ``json.loads`` and checked with
    ``verify_grading_output``. Any exception (request, parsing or
    verification) is printed and the request is retried.

    Args:
        client: object exposing ``chat.completions.create(...)``.
        prompt: full grading prompt.
        length: number of graded items expected in the reply.
        max_retries: maximum number of failed attempts before giving up.

    Returns:
        The parsed reply, or the placeholder built by ``build_dummy_output``
        when no valid reply could be obtained.
    """
    max_tokens = 256
    curr_retries = 0
    # explicit success flag: the fallback below must not depend on how the
    # loop ended, otherwise a negative / non-integer max_retries would leave
    # the result unbound
    succeeded = False
    content = None
    while curr_retries < max_retries:
        try:
            response = client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model="gpt-4o-2024-05-13",
                response_format={"type": "json_object"},
                n=1,
                max_tokens=max_tokens,
                temperature=0,
                top_p=1,
                seed=42,
            ).choices[0].message.content
            parsed = json.loads(response)
            verify_grading_output(parsed, length)
            # only accept the reply once it has passed verification
            content = parsed
            succeeded = True
            break
        except Exception as e:
            print(f"Error: {e}")
            # increase the max_tokens if the response is too long
            if 'Unterminated string starting at' in str(e):
                if max_tokens >= 1024:
                    # already at the cap: give up and fall back to dummy data
                    break
                max_tokens = min(1024, max_tokens * 2) # double the max_tokens
                print(f"Retrying with max_tokens: {max_tokens}")
            # otherwise, retry the request
            curr_retries += 1
    # if failed to get response, return dummy data
    if not succeeded:
        print(f"Failed to get response for prompt: {prompt}")
        content = build_dummy_output(length)
    return content

def build_json_keys(length):
    """Return the string form of the JSON keys expected for ``length`` items.

    For every item index i in 1..length the list contains
    ``extract_answer_T{i}`` followed by ``score_T{i}``; the list is returned
    as its ``str()`` representation.
    """
    # specify the keys for gpt-4o's json response
    names = [
        name
        for n in range(1, length + 1)
        for name in (f"extract_answer_T{n}", f"score_T{n}")
    ]
    return str(names)

def populate_grading_inputs(batch):
    """Render a batch of (key, response, answer) triplets as grading text.

    Each triplet becomes one ``T{n}`` section (numbered from 1) listing the
    response and the ground truth; the first element of each triplet is
    ignored. An empty batch yields an empty string.
    """
    template = "T{}:\nResponse {}: {}\nGround Truth {}: {}\n\n"
    return "".join(
        template.format(n, n, response, n, answer)
        for n, (_, response, answer) in enumerate(batch, start=1)
    )

def verify_grading_output(data, length_data):
    """Check that a parsed grading reply contains every expected field.

    For each item index i in 1..length_data, ``data`` must contain the keys
    ``extract_answer_T{i}`` and ``score_T{i}``, and the score must be 0 or 1.

    Args:
        data: parsed grading reply (a mapping).
        length_data: number of graded items expected.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: on the first missing key or invalid score.
    """
    # check the integrity of keys and values
    # (raised explicitly so the checks are not stripped under ``python -O``)
    for i in range(1, length_data + 1):
        answer_key = f"extract_answer_T{i}"
        score_key = f"score_T{i}"
        if answer_key not in data:
            raise AssertionError(f"{answer_key} is not found in {data}")
        if score_key not in data:
            raise AssertionError(f"{score_key} is not found in {data}")
        if data[score_key] not in [0, 1]:
            raise AssertionError(f"{score_key} is not in [0, 1]")
    return True

def build_dummy_output(length_data):
    """Build a placeholder grading result for ``length_data`` items.

    Used when a reply could not be obtained or parsed: every item gets the
    answer text "Failed to parse response" and a score of -1.
    """
    placeholder = {}
    for n in range(1, length_data + 1):
        placeholder.update({
            f"extract_answer_T{n}": "Failed to parse response",
            f"score_T{n}": -1,
        })
    return placeholder

def preprocess_descriptive_grading_queries(input, resp, num_templates=19):
    """Regroup model responses by question template id.

    Args:
        input: mapping whose values each carry ``figure_id``, ``qids`` and
            ``answers``.
        resp: mapping from ``"{figure_id}_{question index}"`` to a dict with
            a ``response`` entry.
        num_templates: number of template ids; groups are created for ids
            1..num_templates.

    Returns:
        Dict mapping each template id to a list of
        ``(resp_key, response, answer)`` tuples.
    """
    # one (initially empty) bucket per template id
    by_template = {}
    for template_id in range(1, num_templates + 1):
        by_template[template_id] = []
    for record in input.values():
        fig = record['figure_id']
        for idx, template_id in enumerate(record['qids']):
            # responses are keyed by figure id plus question index
            key = f"{fig}_{idx}"
            triplet = (key, resp[key]['response'], record['answers'][idx])
            by_template[template_id].append(triplet)
    return by_template

def build_descriptive_grading_queries(groups, nq_per_query=5):
    """Turn grouped responses into batched grading prompts.

    Args:
        groups: mapping from template id to a list of
            ``(resp_key, response, answer)`` tuples.
        nq_per_query: maximum number of tuples graded in one prompt.

    Returns:
        List of dicts, each holding the ``resp_keys`` of one batch and the
        ``grading_query`` text built for it.
    """
    grading_queries = []
    for template_id, triplets in groups.items():
        # walk the group in slices of at most nq_per_query items
        for start in range(0, len(triplets), nq_per_query):
            chunk = triplets[start : start + nq_per_query]
            n_triplets = len(chunk)
            # template-specific question and the json keys expected back
            overarching_question = DESCRIPTIVE_GRADING_QMAP[template_id]
            expected_keys = build_json_keys(n_triplets)
            # fill the placeholders of the shared prompt prefix
            header = DESCRIPTIVE_GRADING_PREFIX
            header = header.replace("<|NUM_TRIPLETS|>", str(n_triplets))
            header = header.replace("<|OVERARCHING_QUESTION|>", overarching_question)
            header = header.replace("<|JSON_KEYS|>", expected_keys)
            # in-context grading example for this template
            example = get_rubric(template_id)
            # prompt + example + model responses
            full_prompt = header + example + populate_grading_inputs(chunk)
            grading_queries.append({
                'resp_keys': [item[0] for item in chunk],
                'grading_query': full_prompt,
            })
    return grading_queries

def postprocess_descriptive_grading_queries(queries):
    """Flatten graded batches into a per-response score table.

    Each query dict carries its ``resp_keys`` plus ``extract_answer_T{n}``
    and ``score_T{n}`` entries, where n is the 1-based position of the
    response key within the batch.

    Returns:
        Dict mapping each response key to a dict with ``resp_id``,
        ``extracted_answer`` and ``score``.
    """
    graded = {}
    for record in queries:
        for position, key in enumerate(record['resp_keys'], start=1):
            graded[key] = {
                'resp_id': key,
                'extracted_answer': record[f"extract_answer_T{position}"],
                'score': record[f"score_T{position}"],
            }
    return graded

def descriptive_query_helper(qid, subplot_loc):
    """Build the question text for a template id and subplot location.

    Args:
        qid: question template id, used as the key into
            ``DESCRIPTIVE_RESP_INST``.
        subplot_loc: either a list (first element 0 means a single plot,
            otherwise the first two elements are used as row and column) or
            a string describing the subplot.

    Returns:
        The template for ``qid``; for ids other than 18 and 19 it is
        formatted with a prefix naming the subplot.

    Raises:
        ValueError: if ``subplot_loc`` is neither a list nor a string (only
            checked for ids other than 18 and 19).
    """
    if qid in [18, 19]:
        # skip subplot location when asking about the layout of the subplots
        return DESCRIPTIVE_RESP_INST[qid]
    if not isinstance(subplot_loc, (list, str)):
        raise ValueError(f"Invalid subplot_loc: {subplot_loc}")
    if isinstance(subplot_loc, str):
        # when subplots do not form a grid
        location = f"For {subplot_loc}, "
    elif subplot_loc[0] == 0:
        # when there is only one subplot
        location = "For the current plot, "
    else:
        # when there are multiple subplots
        location = f"For the subplot at row {subplot_loc[0]} and column {subplot_loc[1]}, "
    # return the question with the subplot location
    return DESCRIPTIVE_RESP_INST[qid].format(location)

def build_descriptive_quries(data, image_dir):
    """Build one query record per (figure, question) pair.

    Args:
        data: mapping whose values each carry ``figure_id``, ``qids`` and
            ``subplot_loc``.
        image_dir: directory containing the ``{figure_id}.jpg`` images.

    Returns:
        Dict keyed by ``"{figure_id}_{question index}"``; each value holds
        the figure id, figure path, question index, template id and the
        question text.
    """
    queries = {}
    for entry in data.values():
        fig_id = entry['figure_id']
        figure_path = os.path.join(image_dir, f"{fig_id}.jpg")
        for subq_idx, qid in enumerate(entry['qids']):
            # mapping from template id and subplot location to the question
            question = descriptive_query_helper(qid, entry['subplot_loc'])
            queries[f"{fig_id}_{subq_idx}"] = {
                'figure_id': fig_id, # figure_id
                'figure_path': figure_path, # figure_path (dropped later)
                'subq_idx': subq_idx, # index of the question for the given figure
                'qid': qid, # template id
                'question': question, # question content
            }
    return queries