Commit 5faaff4a authored by Jerry Liu

cr

parent 05f99498
@@ -11,9 +11,9 @@ from pathlib import Path
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp"
)
def load_model():
def load_model(model_dir: str = "data_sql"):
"""Load model."""
path = get_model_path()
path = get_model_path(model_dir=model_dir)
config_path = path / "adapter_config.json"
model_path = path / "adapter_model.bin"
@@ -28,9 +28,9 @@ def load_model():
stub.model_dict["model"] = model_data
@stub.local_entrypoint()
def main(output_dir: str):
def main(output_dir: str, model_dir: str = "data_sql"):
# copy adapter_config.json and adapter_model.bin files into dict
load_model.call()
load_model.call(model_dir=model_dir)
model_data = stub.model_dict["model"]
config_data = stub.model_dict["config"]
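With model_dir threaded through load_model and the local entrypoint, the adapter location is now chosen per run instead of being fixed. A minimal invocation sketch, assuming the script is launched with modal run and that Modal exposes entrypoint parameters as kebab-case flags; the filename src/download_weights.py is a placeholder, since the diff does not show it:

modal run src/download_weights.py --output-dir ./lora_out --model-dir data_sql

The next file in the diff is a new evaluation script that takes the same model_dir option.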
from typing import Optional
from modal import gpu, method, Retries
from modal.cls import ClsMixin
import json
from .common import (
output_vol,
stub,
VOL_MOUNT_PATH,
get_data_path,
generate_prompt_sql
)
from .inference_utils import OpenLlamaLLM
@stub.function(
gpu="A100",
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def run_evals(
sample_data,
model_dir: str = "data_sql",
use_finetuned_model: bool = True
):
llm = OpenLlamaLLM(
model_dir=model_dir, max_new_tokens=256, use_finetuned_model=use_finetuned_model
)
inputs_outputs = []
for row_dict in sample_data:
prompt = generate_prompt_sql(row_dict["input"], row_dict["context"])
completion = llm.complete(
prompt,
do_sample=True,
temperature=0.3,
top_p=0.85,
top_k=40,
num_beams=1,
max_new_tokens=600,
repetition_penalty=1.2,
)
inputs_outputs.append((row_dict, completion.text))
return inputs_outputs
@stub.function(
gpu="A100",
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def run_evals_all(
data_dir: str = "data_sql",
model_dir: str = "data_sql",
num_samples: int = 10,
):
# evaluate a sample from the same training set
from datasets import load_dataset
data_path = get_data_path(data_dir).as_posix()
data = load_dataset("json", data_files=data_path)
# load sample data
sample_data = data["train"].shuffle().select(range(num_samples))
print('*** Running inference with finetuned model ***')
inputs_outputs_0 = run_evals.call(
sample_data=sample_data,
model_dir=model_dir,
use_finetuned_model=True
)
print('*** Running inference with base model ***')
input_outputs_1 = run_evals.call(
sample_data=sample_data,
model_dir=model_dir,
use_finetuned_model=False
)
return inputs_outputs_0, input_outputs_1
@stub.local_entrypoint()
def main(data_dir: str = "data_sql", model_dir: str = "data_sql", num_samples: int = 10):
"""Main function."""
inputs_outputs_0, input_outputs_1 = run_evals_all.call(
data_dir=data_dir,
model_dir=model_dir,
num_samples=num_samples
)
for idx, (row_dict, completion) in enumerate(inputs_outputs_0):
print(f"Input {idx}: " + str(row_dict))
print(f"Output {idx} (finetuned model): " + str(completion))
print(f"Output {idx} (base model): " + str(input_outputs_1[idx][1]))
# print('*** Running inference with finetuned model ***')
# inputs_outputs_0 = run_evals.call(
# data_dir=data_dir,
# model_dir=model_dir,
# num_samples=num_samples,
# use_finetuned_model=True
# )
# for row_dict, completion in inputs_outputs_0:
# print("Input: " + str(row_dict))
# print("Output: " + str(completion))
# print('*** Running inference with base model ***')
# input_outputs_1 = run_evals.call(
# data_dir=data_dir,
# model_dir=model_dir,
# num_samples=num_samples,
# use_finetuned_model=False
# )
# for row_dict, completion in input_outputs_1:
# print("Input: " + str(row_dict))
# print("Output: " + str(completion))
\ No newline at end of file
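The evaluation entrypoint samples num_samples rows from the training split and sends each prompt through both the fine-tuned adapter and the base model so the outputs can be compared side by side. A hedged invocation sketch; the filename src/eval_sql.py is an assumption, as the diff does not show it:

modal run src/eval_sql.py --data-dir data_sql --model-dir data_sql --num-samples 10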
@@ -12,6 +12,7 @@ from .common import (
get_model_path,
generate_prompt_sql
)
from .inference_utils import OpenLlamaLLM
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
@@ -26,103 +27,6 @@ from llama_index import SQLDatabase, ServiceContext, Prompt
from typing import Any
@stub.cls(
gpu=gpu.A100(memory=20),
network_file_systems={VOL_MOUNT_PATH: output_vol},
)
class OpenLlamaLLM(CustomLLM, ClsMixin):
"""OpenLlamaLLM is a custom LLM that uses the OpenLlamaModel."""
def __init__(
self,
model_dir: str = "data_sql",
max_new_tokens: int = 128,
callback_manager: Optional[CallbackManager] = None,
use_finetuned_model: bool = True,
):
super().__init__(callback_manager=callback_manager)
import sys
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
CHECKPOINT = get_model_path(model_dir)
load_8bit = False
device = "cuda"
self.tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
MODEL_PATH,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
if use_finetuned_model:
model = PeftModel.from_pretrained(
model,
CHECKPOINT,
torch_dtype=torch.float16,
)
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
self.model = model
self.device = device
self._max_new_tokens = max_new_tokens
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=2048,
num_output=self._max_new_tokens,
model_name="finetuned_openllama_sql"
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
import torch
from transformers import GenerationConfig
# TODO: TO fill
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(self.device)
# tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
# print(tokens)
generation_config = GenerationConfig(
**kwargs,
)
with torch.no_grad():
generation_output = self.model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=self._max_new_tokens,
)
s = generation_output.sequences[0]
output = self.tokenizer.decode(s, skip_special_tokens=True)
# NOTE: parsing response this way means that the model can mostly
# only be used for text-to-SQL, not other purposes
response_text = output.split("### Response:")[1].strip()
return CompletionResponse(text=response_text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
raise NotImplementedError()
@stub.function(
gpu="A100",
retries=Retries(
@@ -170,20 +74,25 @@ def run_query(query: str, model_dir: str = "data_sql", use_finetuned_model: bool
)
response = query_engine.query(query)
print(f'Model output: {str(response.metadata["sql_query"])}')
print(
f'Model output: \n'
f'SQL Query: {str(response.metadata["sql_query"])}'
f"Response: {response.response}"
)
return response
@stub.local_entrypoint()
def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_finetuned_model: Optional[bool] = None):
def main(query: str, sqlite_file_path: str, model_dir: str = "data_sql", use_finetuned_model: str = "True"):
"""Main function."""
fp = open(sqlite_file_path, "rb")
stub.data_dict["sqlite_data"] = fp.read()
if use_finetuned_model is None:
if use_finetuned_model == "None":
# try both
run_query.call(query, model_dir=model_dir, use_finetuned_model=True)
run_query.call(query, model_dir=model_dir, use_finetuned_model=False)
else:
run_query.call(query, model_dir=model_dir, use_finetuned_model=use_finetuned_model)
bool_toggle = use_finetuned_model == "True"
run_query.call(query, model_dir=model_dir, use_finetuned_model=bool_toggle)
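Because use_finetuned_model now arrives as a string rather than an Optional[bool], passing the literal "None" makes the entrypoint run the query against both the fine-tuned and the base model, while "True" or "False" selects one of them. A sketch of a call, with the filename and database path as placeholders:

modal run src/inference_sql_llamaindex.py --query "Which city has the highest population?" --sqlite-file-path data/cities.db --use-finetuned-model None

The remainder of the diff is the new inference_utils module that both the evaluation and query scripts now import.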
"""Get inference utils."""
from typing import Optional
from modal import gpu
from modal.cls import ClsMixin
from .common import (
MODEL_PATH,
output_vol,
stub,
VOL_MOUNT_PATH,
get_model_path,
)
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
CustomLLM,
LLMMetadata,
CompletionResponse,
CompletionResponseGen,
)
from llama_index.llms.base import llm_completion_callback
from typing import Any
@stub.cls(
gpu=gpu.A100(memory=20),
network_file_systems={VOL_MOUNT_PATH: output_vol},
)
class OpenLlamaLLM(CustomLLM, ClsMixin):
"""OpenLlamaLLM is a custom LLM that uses the OpenLlamaModel."""
def __init__(
self,
model_dir: str = "data_sql",
max_new_tokens: int = 128,
callback_manager: Optional[CallbackManager] = None,
use_finetuned_model: bool = True,
):
super().__init__(callback_manager=callback_manager)
import sys
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
CHECKPOINT = get_model_path(model_dir)
load_8bit = False
device = "cuda"
self.tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
MODEL_PATH,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
if use_finetuned_model:
model = PeftModel.from_pretrained(
model,
CHECKPOINT,
torch_dtype=torch.float16,
)
if not load_8bit:
model.half() # seems to fix bugs for some users.
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
self.model = model
self.device = device
self._max_new_tokens = max_new_tokens
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=2048,
num_output=self._max_new_tokens,
model_name="finetuned_openllama_sql"
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
import torch
from transformers import GenerationConfig
# TODO: TO fill
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(self.device)
# tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
# print(tokens)
generation_config = GenerationConfig(
**kwargs,
)
with torch.no_grad():
generation_output = self.model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=self._max_new_tokens,
)
s = generation_output.sequences[0]
output = self.tokenizer.decode(s, skip_special_tokens=True)
# NOTE: parsing response this way means that the model can mostly
# only be used for text-to-SQL, not other purposes
response_text = output.split("### Response:")[1].strip()
return CompletionResponse(text=response_text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
raise NotImplementedError()
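The relocated class is still meant to be instantiated inside a GPU-backed Modal function (as run_evals does above), not on the local machine. A minimal usage sketch under that assumption, with an invented example row:

from .common import generate_prompt_sql
from .inference_utils import OpenLlamaLLM

# hypothetical example row with the same fields used in run_evals
row = {"input": "How many users signed up in 2023?", "context": "CREATE TABLE users (id INT, signup_date DATE)"}
prompt = generate_prompt_sql(row["input"], row["context"])
llm = OpenLlamaLLM(model_dir="data_sql", max_new_tokens=256, use_finetuned_model=True)
print(llm.complete(prompt, temperature=0.3, top_p=0.85).text)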