Skip to content
Snippets Groups Projects
descriptive_utils.py 7.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • Colin Wang's avatar
    Colin Wang committed
    import os, json
    from constants import DESCRIPTIVE_RESP_INST, DESCRIPTIVE_GRADING_PREFIX, \
                DESCRIPTIVE_GRADING_QMAP, DESCRIPTIVE_GRADING_ICL
    
    def get_rubric(qid):
        """Return the in-context grading example for a question template id.

        Each template id belongs to one rubric category; the category name is
        the key used to look up the example in DESCRIPTIVE_GRADING_ICL.

        Args:
            qid: template id of the descriptive question.

        Returns:
            The DESCRIPTIVE_GRADING_ICL entry for the category of `qid`.

        Raises:
            AssertionError: if `qid` does not belong to any category.
        """
        # template ids grouped by the rubric category used to grade them
        categories = (
            ((1,), 'title'),
            ((2, 3, 4, 5, 6, 7), 'ocr'),
            ((8, 9, 10, 12, 14, 15, 17, 19), 'quant'),
            ((11,), 'bool'),
            ((13,), 'enum'),
            ((16,), 'trend'),
            ((18,), 'layout'),
        )
        instruction = None
        for template_ids, category in categories:
            if qid in template_ids:
                # the groups are disjoint, so the first match is the only match
                instruction = DESCRIPTIVE_GRADING_ICL[category]
                break
        assert instruction is not None, f"Instruction for qid {qid} is not found."
        return instruction
    
    def get_descriptive_result_gpt(client, prompt, length, max_retries=10):
        """Request a JSON grading result from the chat model, retrying on failure.

        The request is sent with a 256-token completion budget. The reply is
        parsed as JSON and checked with verify_grading_output. Any exception
        (request error, JSON parse error, failed verification) is printed and
        the request is retried. When the error text indicates a truncated JSON
        string, the token budget is doubled (capped at 1024) before retrying;
        if the budget is already at the cap, retrying stops immediately.

        Args:
            client: object exposing `chat.completions.create(...)`
                (presumably an OpenAI client -- confirm against the caller).
            prompt: full grading prompt sent as a single user message.
            length: number of graded items expected in the reply.
            max_retries: maximum number of request attempts.

        Returns:
            The parsed and verified JSON dict, or the placeholder produced by
            build_dummy_output(length) when no attempt succeeds.
        """
        max_tokens = 256
        for _ in range(max_retries):
            try:
                response = client.chat.completions.create(
                    messages=[
                        {
                            "role": "user",
                            "content": prompt,
                        }
                    ],
                    model="gpt-4o-2024-05-13",
                    response_format={"type": "json_object"},
                    n=1,
                    max_tokens=max_tokens,
                    temperature=0,
                    top_p=1,
                    seed=42,
                ).choices[0].message.content
                content = json.loads(response)
                verify_grading_output(content, length)
                return content
            except Exception as e:
                # deliberate best-effort: any failure is reported and retried
                print(f"Error: {e}")
                # increase the max_tokens if the response is too long
                if 'Unterminated string starting at' in str(e):
                    if max_tokens >= 1024:
                        # already at the cap; a larger budget is not available
                        break
                    max_tokens = min(1024, max_tokens * 2)  # double the max_tokens
                    print(f"Retrying with max_tokens: {max_tokens}")
        # if failed to get response, return dummy data
        print(f"Failed to get response for prompt: {prompt}")
        return build_dummy_output(length)
    
    def build_json_keys(length):
        """Return the string form of the JSON keys expected in a grading reply.

        For every item index from 1 to `length` the list holds an
        "extract_answer_T<i>" key followed by a "score_T<i>" key.

        Args:
            length: number of graded items in the batch.

        Returns:
            `str()` of the list of key names.
        """
        # specify the keys for gpt-4o's json response
        names = [
            key
            for idx in range(1, length + 1)
            for key in (f"extract_answer_T{idx}", f"score_T{idx}")
        ]
        return str(names)
    
    def populate_grading_inputs(batch):
        """Format a batch of (key, response, answer) tuples for the grading prompt.

        Each tuple becomes one "T<i>" section holding the model response and
        the ground truth, numbered from 1; the first tuple element is unused.

        Args:
            batch: iterable of 3-tuples (index, response, answer).

        Returns:
            The concatenated sections, or an empty string for an empty batch.
        """
        template = "T{}:\nResponse {}: {}\nGround Truth {}: {}\n\n"
        sections = [
            template.format(pos, pos, model_response, pos, ground_truth)
            for pos, (_, model_response, ground_truth) in enumerate(batch, start=1)
        ]
        return "".join(sections)
    
    def verify_grading_output(data, length_data):
        """Check that a parsed grading reply has every expected key and score.

        For each item index from 1 to `length_data`, the reply must contain
        "extract_answer_T<i>" and "score_T<i>", and the score must be 0 or 1.

        Args:
            data: parsed JSON reply (a mapping of key name to value).
            length_data: number of graded items expected in the reply.

        Returns:
            True when every check passes.

        Raises:
            AssertionError: if a key is missing or a score is not 0 or 1.
        """
        # check the integrity of keys and values
        # explicit raises keep the checks active under `python -O`
        for i in range(1, length_data + 1):
            answer_key = f"extract_answer_T{i}"
            score_key = f"score_T{i}"
            if answer_key not in data:
                raise AssertionError(f"{answer_key} is not found in {data}")
            if score_key not in data:
                raise AssertionError(f"{score_key} is not found in {data}")
            if data[score_key] not in [0, 1]:
                raise AssertionError(f"{score_key} is not in [0, 1]")
        return True
    
    def build_dummy_output(length_data):
        """Build a placeholder grading result for a reply that could not be used.

        Args:
            length_data: number of graded items the reply should have covered.

        Returns:
            A dict with, for each index from 1 to `length_data`, the extracted
            answer set to "Failed to parse response" and the score set to -1.
        """
        # if failed to parse the response, return dummy data
        placeholder = {}
        for idx in range(1, length_data + 1):
            placeholder.update({
                f"extract_answer_T{idx}": "Failed to parse response",
                f"score_T{idx}": -1,
            })
        return placeholder
    
    def preprocess_descriptive_grading_queries(input, resp, num_templates=19):
        """Group model responses by question template id.

        Args:
            input: mapping whose values are dicts with 'figure_id', 'qids' and
                'answers' entries.
            resp: mapping from "<figure_id>_<question index>" to a dict with a
                'response' entry.
            num_templates: number of template ids; groups are created for ids
                1 through `num_templates`.

        Returns:
            A dict mapping each template id to a list of
            (resp_key, response, answer) tuples.
        """
        # group the responses based on the template id instead of figure id
        groups = {}
        for template_id in range(1, num_templates + 1):
            groups[template_id] = []
        for record in input.values():
            fig_id = record['figure_id']
            for idx, qid in enumerate(record['qids']):
                # figure_id with question index
                key = f"{fig_id}_{idx}"
                entry = (key, resp[key]['response'], record['answers'][idx])
                groups[qid].append(entry)
        return groups
    
    def build_descriptive_grading_queries(groups, nq_per_query=5):
        """Build batched grading prompts from responses grouped by template id.

        Args:
            groups: mapping from template id to a list of
                (resp_key, response, answer) tuples.
            nq_per_query: maximum number of responses graded in one prompt.

        Returns:
            A list of dicts, each with 'resp_keys' (the response keys in the
            batch) and 'grading_query' (the full prompt text).
        """
        queries = []
        for qid, entries in groups.items():
            # batched evaluation based on number of questions per query
            for start in range(0, len(entries), nq_per_query):
                # chunk: list of tuples (resp_key, response, answer)
                chunk = entries[start:start + nq_per_query]
                batch_size = len(chunk)
                # fill in batch size, template question and json keys spec
                header = DESCRIPTIVE_GRADING_PREFIX
                substitutions = (
                    ("<|NUM_TRIPLETS|>", str(batch_size)),
                    ("<|OVERARCHING_QUESTION|>", DESCRIPTIVE_GRADING_QMAP[qid]),
                    ("<|JSON_KEYS|>", build_json_keys(batch_size)),
                )
                for placeholder, value in substitutions:
                    header = header.replace(placeholder, value)
                # prompt + in-context example for the template + model responses
                prompt_text = header + get_rubric(qid) + populate_grading_inputs(chunk)
                queries.append({
                    'resp_keys': [item[0] for item in chunk],
                    'grading_query': prompt_text,
                })
        return queries
    
    def postprocess_descriptive_grading_queries(queries):
        """Flatten graded batches into one record per response key.

        Args:
            queries: iterable of dicts, each holding 'resp_keys' plus the
                "extract_answer_T<i>" and "score_T<i>" entries for its batch.

        Returns:
            A dict mapping each response key to a dict with 'resp_id',
            'extracted_answer' and 'score'.
        """
        results = {}
        for graded in queries:
            # the T<i> suffixes are 1-based positions within the batch
            for position, key in enumerate(graded['resp_keys'], start=1):
                results[key] = {
                    'resp_id': key,
                    'extracted_answer': graded[f"extract_answer_T{position}"],
                    'score': graded[f"score_T{position}"],
                }
        return results
    
    def descriptive_query_helper(qid, subplot_loc):
        """Return the question text for a template id and subplot location.

        Args:
            qid: template id of the descriptive question.
            subplot_loc: either a list whose first two items are the row and
                column, or a string describing the subplot.

        Returns:
            DESCRIPTIVE_RESP_INST[qid] as is for template ids 18 and 19,
            otherwise that entry formatted with a location prefix.

        Raises:
            ValueError: if `subplot_loc` is neither a list nor a string
                (not checked for template ids 18 and 19).
        """
        if qid in (18, 19):
            # skip subplot location when asking about the layout of the subplots
            return DESCRIPTIVE_RESP_INST[qid]
        if not isinstance(subplot_loc, (list, str)):
            raise ValueError(f"Invalid subplot_loc: {subplot_loc}")
        if isinstance(subplot_loc, str):
            # when subplots do not form a grid
            prefix = f"For {subplot_loc}, "
        elif subplot_loc[0] == 0:
            # when there is only one subplot
            prefix = "For the current plot, "
        else:
            # when there are multiple subplots
            prefix = f"For the subplot at row {subplot_loc[0]} and column {subplot_loc[1]}, "
        # return the question with the subplot location
        return DESCRIPTIVE_RESP_INST[qid].format(prefix)
    
    def build_descriptive_quries(data, image_dir):
        """Build one query record per (figure, question) pair.

        Args:
            data: mapping whose values are dicts with 'figure_id', 'qids' and
                'subplot_loc' entries.
            image_dir: directory holding the figure images, named
                "<figure_id>.jpg".

        Returns:
            A dict keyed by "<figure_id>_<question index>"; each value holds
            'figure_id', 'figure_path', 'subq_idx', 'qid' and 'question'.
        """
        queries = {}
        for record in data.values():
            fig_id = record['figure_id']
            image_file = os.path.join(image_dir, f"{fig_id}.jpg")
            for idx, template_id in enumerate(record['qids']):
                # mapping from template id and subplot location to the question
                text = descriptive_query_helper(template_id, record['subplot_loc'])
                queries[f"{fig_id}_{idx}"] = {
                    'figure_id': fig_id,  # figure_id
                    'figure_path': image_file,  # figure_path (dropped later)
                    'subq_idx': idx,  # index of the question for the given figure
                    'qid': template_id,  # template id
                    'question': text,  # question content
                }
        return queries