Unverified commit f6c3ffd4, authored by Geeta Chauhan, committed by GitHub

Inference updates (#12)

Parents: 18ea0a62, 557e881f
@@ -62,13 +62,11 @@ def main(
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": "[PAD]",
+            "pad_token": "<PAD>",
         }
     )
     chats = format_tokens(dialogs, tokenizer)
     with torch.no_grad():
...
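Note on the hunk above: the old code overrode "eos_token", "bos_token", and "unk_token" with "</s>" and registered "[PAD]"; the new version adds only a dedicated "<PAD>" pad token and leaves Llama's stock special tokens untouched. A minimal sketch of the downstream implication, assuming a standard Hugging Face transformers workflow (the checkpoint path is illustrative, not from this commit): a newly added token has no embedding row, so the model's embeddings usually need resizing before the token can be used.

    from transformers import LlamaForCausalLM, LlamaTokenizer

    model_name = "path/to/llama-checkpoint"  # illustrative path, not from this commit

    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    # Register only the pad token; eos/bos/unk keep their stock Llama values.
    num_added = tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    model = LlamaForCausalLM.from_pretrained(model_name)
    if num_added > 0:
        # A brand-new token has no trained embedding row yet, so grow the
        # embedding matrix to cover the enlarged vocabulary.
        model.resize_token_embeddings(len(tokenizer))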
@@ -7,6 +7,7 @@ import fire
 import torch
 import os
 import sys
+import time
 from typing import List
 from transformers import LlamaTokenizer
@@ -49,15 +50,13 @@ def main(
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": "[PAD]",
+            "pad_token": "<PAD>",
         }
     )
@@ -88,7 +87,7 @@ def main(
     batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
+    start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
             **batch,
@@ -103,7 +102,8 @@ def main(
             length_penalty=length_penalty,
             **kwargs
         )
+    e2e_inference_time = (time.perf_counter()-start)*1000
+    print(f"the inference time is {e2e_inference_time} ms")
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     # Safety check of the model output
...
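The two timing hunks above bracket model.generate() with time.perf_counter() and report milliseconds. One caveat when timing GPU code: CUDA kernels launch asynchronously, so a stricter measurement would synchronize the device around the timed region. A sketch of that variant, not the committed code (torch.cuda.synchronize() is standard PyTorch):

    import time
    import torch

    def timed_generate(model, batch, **gen_kwargs):
        """Return generate() outputs plus end-to-end latency in milliseconds."""
        if torch.cuda.is_available():
            # Flush pending CUDA work so the timer starts from a clean slate.
            torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**batch, **gen_kwargs)
        if torch.cuda.is_available():
            # Syncing again guards against any straggling asynchronous kernels.
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000
        return outputs, elapsed_ms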
@@ -109,13 +109,11 @@ def main(**kwargs):
     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
         {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": '[PAD]',
+            "pad_token": "<PAD>",
         }
     )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
...
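For readers following the unchanged PEFT branch in the hunk above: generate_peft_config is a helper in this repo that builds the config from train_config and the CLI kwargs. A minimal stand-in using the peft library directly might look like the sketch below; the LoRA hyperparameters are illustrative, not the repo's defaults.

    from peft import LoraConfig, get_peft_model

    peft_config = LoraConfig(
        r=8,                                  # illustrative rank
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],  # attention projections to adapt
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # reports trainable vs. total parameters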