Skip to content
Snippets Groups Projects
Commit 57d82dff authored by duzx16's avatar duzx16
Browse files

Better test script for transformers t5

parent 913ce150
No related branches found
No related tags found
No related merge requests found
from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
device = 'cuda:1'
tokenizer = T5Tokenizer.from_pretrained("t5-large") tokenizer = T5Tokenizer.from_pretrained("t5-large")
model = T5Model.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large") model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
model = model.to('cuda') model = model.to(device)
model.eval() model.eval()
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to('cuda') input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to('cuda') decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
breakpoint()
output = model(input_ids=input_ids, labels=decoder_input_ids) output = model(input_ids=input_ids, labels=decoder_input_ids)
output.loss.backward() output.loss.backward()
breakpoint() a = 1
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment