# inference_sql.py
# Last commit: "cr" (Jerry Liu)
from typing import Optional

from modal import gpu, method, Retries
from modal.cls import ClsMixin
import json

from .common import (
    output_vol,
    stub,
    VOL_MOUNT_PATH,
    get_data_path,
    generate_prompt_sql
)
from .inference_utils import OpenLlamaLLM


@stub.function(
    gpu="A100",
    retries=Retries(
        max_retries=3,
        initial_delay=5.0,
        backoff_coefficient=2.0,
    ),
    timeout=60 * 60 * 2,
    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
    cloud="gcp",
)
def run_evals(
    sample_data,
    model_dir: str = "data_sql",
    use_finetuned_model: bool = True
):
    """Generate one SQL completion per row of ``sample_data`` with a single model.

    Each row must provide ``"input"`` and ``"context"`` entries; they are turned
    into a prompt with ``generate_prompt_sql`` and completed with sampling.

    Args:
        sample_data: Iterable of row mappings with ``"input"`` and ``"context"`` keys.
        model_dir: Model directory passed to ``OpenLlamaLLM``.
        use_finetuned_model: Forwarded to ``OpenLlamaLLM`` to choose between the
            finetuned and the base model.

    Returns:
        List of ``(row, completion_text)`` tuples, in the iteration order of
        ``sample_data``.
    """
    # Sampling settings shared by every prompt in this run.
    # NOTE(review): the per-call max_new_tokens=600 differs from the 256 given to
    # the constructor; presumably the per-call value wins — confirm in OpenLlamaLLM.
    generation_kwargs = dict(
        do_sample=True,
        temperature=0.3,
        top_p=0.85,
        top_k=40,
        num_beams=1,
        max_new_tokens=600,
        repetition_penalty=1.2,
    )
    model = OpenLlamaLLM(
        model_dir=model_dir,
        max_new_tokens=256,
        use_finetuned_model=use_finetuned_model,
    )

    def _complete_row(row):
        # Build the prompt from the row and return only the generated text.
        sql_prompt = generate_prompt_sql(row["input"], row["context"])
        return model.complete(sql_prompt, **generation_kwargs).text

    return [(row, _complete_row(row)) for row in sample_data]


@stub.function(
    gpu="A100",
    retries=Retries(
        max_retries=3,
        initial_delay=5.0,
        backoff_coefficient=2.0,
    ),
    timeout=60 * 60 * 2,
    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
    cloud="gcp",
)
def run_evals_all(
    data_dir: str = "data_sql",
    model_dir: str = "data_sql",
    num_samples: int = 10,
    seed: Optional[int] = None,
):
    """Compare finetuned vs. base model completions on one random data sample.

    Loads the JSON dataset at ``get_data_path(data_dir)``, draws a random sample
    from its ``"train"`` split, and runs ``run_evals`` on that same sample twice:
    first with the finetuned model, then with the base model.

    Args:
        data_dir: Dataset directory name passed to ``get_data_path``.
        model_dir: Model directory name forwarded to ``run_evals``.
        num_samples: Number of rows to evaluate. Clamped to the size of the
            train split, so asking for more rows than exist does not raise.
        seed: Optional shuffle seed. Pass an int for a reproducible sample;
            ``None`` (default) keeps the previous unseeded behaviour.

    Returns:
        ``(finetuned_results, base_results)``: two lists of
        ``(row_dict, completion_text)`` pairs built from the same sample.
    """
    # NOTE(review): this function only orchestrates two run_evals calls and does
    # no model work itself; the A100 requested above may be unnecessary — confirm.

    # Lazy import: `datasets` is only needed when this function actually runs.
    from datasets import load_dataset

    # Evaluate a sample drawn from the same data used for training.
    data_path = get_data_path(data_dir).as_posix()
    train_split = load_dataset("json", data_files=data_path)["train"]

    # Clamp so select() never indexes past the end of the split.
    sample_size = min(num_samples, len(train_split))
    sample_data = train_split.shuffle(seed=seed).select(range(sample_size))

    print('*** Running inference with finetuned model ***')
    inputs_outputs_0 = run_evals.call(
        sample_data=sample_data,
        model_dir=model_dir,
        use_finetuned_model=True,
    )

    print('*** Running inference with base model ***')
    inputs_outputs_1 = run_evals.call(
        sample_data=sample_data,
        model_dir=model_dir,
        use_finetuned_model=False,
    )

    return inputs_outputs_0, inputs_outputs_1



@stub.local_entrypoint()
def main(data_dir: str = "data_sql", model_dir: str = "data_sql", num_samples: int = 10):
    """Run the finetuned-vs-base comparison remotely and print results side by side.

    Args:
        data_dir: Dataset directory name forwarded to ``run_evals_all``.
        model_dir: Model directory name forwarded to ``run_evals_all``.
        num_samples: Number of rows to sample, forwarded to ``run_evals_all``.
    """
    finetuned_results, base_results = run_evals_all.call(
        data_dir=data_dir,
        model_dir=model_dir,
        num_samples=num_samples,
    )
    # run_evals_all feeds the same sample to both models, so the two result
    # lists line up row by row and can be paired with zip.
    for idx, ((row_dict, finetuned_completion), (_, base_completion)) in enumerate(
        zip(finetuned_results, base_results)
    ):
        print(f"Input {idx}: {row_dict}")
        print(f"Output {idx} (finetuned model): {finetuned_completion}")
        print(f"Output {idx} (base model): {base_completion}")