Skip to content
Snippets Groups Projects
Unverified commit f6c3ffd4, authored by Geeta Chauhan, committed by GitHub
Browse files

Inference updates (#12)

parents 18ea0a62 557e881f
Branches
Tags
No related merge requests found
......@@ -62,13 +62,11 @@ def main(
tokenizer = LlamaTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "</s>",
"unk_token": "</s>",
"pad_token": "[PAD]",
"pad_token": "<PAD>",
}
)
chats = format_tokens(dialogs, tokenizer)
with torch.no_grad():
......
......@@ -7,6 +7,7 @@ import fire
import torch
import os
import sys
import time
from typing import List
from transformers import LlamaTokenizer
......@@ -49,15 +50,13 @@ def main(
# Set the seeds for reproducibility
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
model = load_model(model_name, quantization)
tokenizer = LlamaTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "</s>",
"unk_token": "</s>",
"pad_token": "[PAD]",
"pad_token": "<PAD>",
}
)
......@@ -88,7 +87,7 @@ def main(
batch = tokenizer(user_prompt, return_tensors="pt")
batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter()
with torch.no_grad():
outputs = model.generate(
**batch,
......@@ -103,7 +102,8 @@ def main(
length_penalty=length_penalty,
**kwargs
)
e2e_inference_time = (time.perf_counter()-start)*1000
print(f"the inference time is {e2e_inference_time} ms")
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Safety check of the model output
......
......@@ -109,13 +109,11 @@ def main(**kwargs):
# Load the tokenizer and add special tokens
tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "</s>",
"unk_token": "</s>",
"pad_token": '[PAD]',
}
)
{
"pad_token": "<PAD>",
}
)
if train_config.use_peft:
peft_config = generate_peft_config(train_config, kwargs)
model = get_peft_model(model, peft_config)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment