Commit 56936ad9 authored by Ming Ding

Merge branch 'finer-attn-hooks' of github.com:THUDM/SwissArmyTransformer into finer-attn-hooks

parents bcbc6120 ac77e3bf
@@ -21,20 +21,22 @@ class CachedAutoregressiveMixin(BaseMixin):
         super().__init__()

     @non_conflict
-    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
-        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
-        b, nh, seq_len, hidden_size = k.shape
-        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
-        kw_args['output_this_layer']['mem_kv'] = cache_kv
-        if mem is not None: # the first time, mem is None
-            # might change batch_size
-            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
-            memk, memv = mem[0], mem[1]
-            k = torch.cat((memk, k), dim=2)
-            v = torch.cat((memv, v), dim=2)
-        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+                     **kw_args):
+        if not cross_attention:
+            mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+            b, nh, seq_len, hidden_size = k.shape
+            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+            kw_args['output_this_layer']['mem_kv'] = cache_kv
+            if mem is not None: # the first time, mem is None
+                # might change batch_size
+                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+                memk, memv = mem[0], mem[1]
+                k = torch.cat((memk, k), dim=2)
+                v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)

 class CachedAutoregressiveModel(BaseModel):
...
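Note on the hunk above: with the new cross_attention flag, the cached-autoregressive hook only builds and reuses the key/value cache on the self-attention path and simply delegates cross-attention to old_impl, since cross-attention keys come from the fixed encoder output. Below is a minimal, self-contained sketch of that chaining pattern; the names (base_attention, cached_hook, cache) are illustrative stand-ins, not the SwissArmyTransformer API.

import torch

def base_attention(q, k, v, mask, **kw):
    # toy attention: scores + optional additive bias, softmax, weighted sum
    scores = torch.matmul(q, k.transpose(-1, -2))
    if kw.get('log_attention_weights') is not None:
        scores = scores + kw['log_attention_weights']
    probs = torch.softmax(scores + mask, dim=-1)
    return torch.matmul(probs, v)

def cached_hook(q, k, v, mask, cache, cross_attention=False, old_impl=base_attention, **kw):
    # Only self-attention keys/values are cached; cross-attention just delegates.
    if not cross_attention:
        if cache:  # prepend previously seen keys/values
            k = torch.cat([cache['k'], k], dim=-2)
            v = torch.cat([cache['v'], v], dim=-2)
        cache['k'], cache['v'] = k.detach(), v.detach()
    return old_impl(q, k, v, mask, **kw)

# usage: decode token by token, the cached key length grows by one each step
cache = {}
b, h, d = 1, 2, 4
mask = torch.zeros(1, 1, 1, 1)
for step in range(3):
    q, k, v = (torch.randn(b, h, 1, d) for _ in range(3))
    out = cached_hook(q, k, v, mask, cache)          # self-attention path
    print(step, out.shape, cache['k'].shape[-2])     # cached length: 1, 2, 3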
@@ -3,10 +3,12 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding

 class T5PositionEmbeddingMixin(BaseMixin):
@@ -94,7 +96,7 @@ class T5AttentionMixin(BaseMixin):
         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length, cross_attention=False):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
@@ -106,84 +108,88 @@ class T5AttentionMixin(BaseMixin):
         )
         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         # shape (query_length, key_length, num_heads)
-        if cross_attention:
-            values = self.cross_relative_attention_bias(relative_position_bucket)
-        else:
-            values = self.relative_attention_bias(relative_position_bucket)
+        values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
         return values

-    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
-        attn_module = self.transformer.layers[layer_id].attention
-        seq_length = hidden_states.size(1)
-        memory_length = mems[layer_id].size(1) if mems else 0
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        if position_bias is None:
-            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
-                                           log_attention_weights=position_bias, scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-        kw_args['output_cross_layer']['position_bias'] = position_bias
-        return output
-
-    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
-                                **kw_args):
-        attn_module = self.transformer.layers[layer_id].cross_attention
-        mixed_query_layer = attn_module.query(hidden_states)
-        mixed_x_layer = attn_module.key_value(encoder_outputs)
-        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
-                                           scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-        # Output. [b, s, h]
-        output = attn_module.dense(context_layer)
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-        return output
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
+                     cross_attention=False, **kw_args):
+        log_attention_weights = None
+        if not cross_attention:
+            if position_bias is None:
+                seq_length = q.size(2)
+                key_length = k.size(2)
+                position_bias = self.compute_bias(key_length, key_length)
+                position_bias = position_bias[:, :, -seq_length:, :]
+            kw_args['output_cross_layer']['position_bias'] = position_bias
+            log_attention_weights = position_bias
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
+                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
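Note on the new attention_fn above: the relative-position bias is computed once over the full key length, sliced down to the current query rows, shared with later layers via output_cross_layer, and added to the scores as log_attention_weights. A small sketch of the slicing step follows; compute_bias here is a random stand-in that only mimics T5's (1, num_heads, query_len, key_len) bias shape.

import torch

def compute_bias(query_length, key_length, num_heads=4):
    # stand-in: real code looks up bucketed relative positions in an embedding
    return torch.randn(1, num_heads, query_length, key_length)

key_length, seq_length = 7, 1          # e.g. 6 cached tokens plus 1 new query token
position_bias = compute_bias(key_length, key_length)
position_bias = position_bias[:, :, -seq_length:, :]   # keep only the rows for the new queries
print(position_bias.shape)              # torch.Size([1, 4, 1, 7])
# The sliced bias is then added to the attention scores and stored in
# output_cross_layer so subsequent layers reuse it instead of recomputing.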
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))

     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
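Note on T5DecoderFinalMixin above: with tied embeddings the logits reuse the shared word-embedding matrix after a d_model**-0.5 rescaling; with --no-share-embeddings a separate lm_head weight is used instead. A rough single-device sketch of the same branch, with the model-parallel wrappers omitted and emb standing in for transformer.word_embeddings:

import torch
import torch.nn.functional as F

hidden_size, vocab_size = 8, 32
emb = torch.nn.Embedding(vocab_size, hidden_size)      # shared input embedding
lm_head = torch.nn.Embedding(vocab_size, hidden_size)  # separate head when untied

def final_forward(hidden, tie_word_embeddings=True):
    if tie_word_embeddings:
        # T5 rescales hidden states before reusing the input embedding as output projection
        return F.linear(hidden * hidden_size ** -0.5, emb.weight)
    return F.linear(hidden, lm_head.weight)

hidden = torch.randn(2, 5, hidden_size)                      # [batch, seq, hidden]
print(final_forward(hidden).shape)                           # torch.Size([2, 5, 32])
print(final_forward(hidden, tie_word_embeddings=False).shape)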
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
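Note on T5GatedGeluMLPMixin above: mlp_forward replaces the plain feed-forward with a gated GELU (GEGLU-style) block, multiplying the GELU branch from the existing dense_h_to_4h by a new per-layer linear gate before projecting back down. A minimal sketch with plain nn.Linear standing in for the parallel linear layers:

import math
import torch

def t5_gelu(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

hidden_size, inner = 8, 32
dense_h_to_4h = torch.nn.Linear(hidden_size, inner, bias=False)   # GELU branch
gated_h_to_4h = torch.nn.Linear(hidden_size, inner, bias=False)   # linear gate branch
dense_4h_to_h = torch.nn.Linear(inner, hidden_size, bias=False)   # down projection

x = torch.randn(2, 5, hidden_size)
h = t5_gelu(dense_h_to_4h(x)) * gated_h_to_4h(x)   # elementwise gating
print(dense_4h_to_h(h).shape)                      # torch.Size([2, 5, 8])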
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -205,10 +211,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )

     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -246,6 +261,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')

     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
@@ -254,7 +271,7 @@ class T5Model(EncoderDecoderModel):
         return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
                               cross_attention_mask=cross_attention_mask, **kw_args)

-    def forward(self, enc_input_ids, dec_input_ids, dec_attention_mask, *, enc_attention_mask=None,
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
                 cross_attention_mask=None, **kw_args):
         batch_size, seq_length = enc_input_ids.size()[:2]
         if enc_attention_mask is None:
...
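Note on the T5Model changes above: two new command-line switches control these mixins. The flag names come from add_model_specific_args in the hunk; the argparse wiring below is only a schematic stand-in for SwissArmyTransformer's real argument setup.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gated-gelu-mlp", action='store_true')        # adds T5GatedGeluMLPMixin to encoder and decoder
parser.add_argument("--no-share-embeddings", action='store_true')   # unties the decoder lm_head from the word embeddings
args = parser.parse_args(["--gated-gelu-mlp", "--no-share-embeddings"])

print(args.gated_gelu_mlp, args.no_share_embeddings)  # True True
# T5Model.__init__ checks args.gated_gelu_mlp and passes
# tie_word_embeddings=not args.no_share_embeddings to T5DecoderFinalMixin.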
@@ -218,6 +218,10 @@ class CrossAttention(torch.nn.Module):
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
             mixed_query_layer = self.query(hidden_states)
             mixed_x_layer = self.key_value(encoder_outputs)
             (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
@@ -228,7 +232,8 @@ class CrossAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
-            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
+                                         cross_attention=True, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             # [b, s, hp]
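Note on the CrossAttention change above: the layer now dispatches to a registered 'attention_fn' hook (falling back to standard_attention) and tags the call with cross_attention=True, so a single hooked attention_fn can branch between self- and cross-attention. A toy sketch of that dispatch; the stubs below are illustrative and share only the cross_attention keyword with the real signatures.

def standard_attention_stub(q, k, v, mask, dropout_fn, **kw_args):
    return "attend({} keys)".format(len(k))

def hooked_attention_fn(q, k, v, mask, dropout_fn, cross_attention=False,
                        old_impl=standard_attention_stub, **kw_args):
    # e.g. cache k/v or add position bias only when cross_attention is False
    branch = "cross" if cross_attention else "self"
    return branch + ": " + old_impl(q, k, v, mask, dropout_fn, **kw_args)

hooks = {'attention_fn': hooked_attention_fn}
attention_fn = hooks.get('attention_fn', standard_attention_stub)     # fallback as in the hunk
print(attention_fn(None, [1, 2, 3], None, None, None))                        # self: attend(3 keys)
print(attention_fn(None, [1, 2, 3], None, None, None, cross_attention=True))  # cross: attend(3 keys)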
@@ -394,7 +399,7 @@ class BaseTransformerLayer(torch.nn.Module):
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask = kw_args['cross_attention_mask']
+                assert 'cross_attention_mask' in kw_args
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, **kw_args)
                 # Residual connection.
@@ -504,7 +509,7 @@ class BaseTransformer(torch.nn.Module):
             ) # None means full attention
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1

         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -526,7 +531,7 @@ class BaseTransformer(torch.nn.Module):
             output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
         else:
             output_cross_layer = {}

         output_per_layers = []
         if self.checkpoint_activations:
             # define custom_forward for checkpointing
@@ -534,7 +539,7 @@ class BaseTransformer(torch.nn.Module):
             def custom_forward(*inputs):
                 layers_ = self.layers[start:end]
                 x_, mask = inputs[0], inputs[1]

                 # recover kw_args and output_cross_layer
                 flat_inputs = inputs[2:]
                 kw_args, output_cross_layer = {}, {}
@@ -543,19 +548,19 @@ class BaseTransformer(torch.nn.Module):
                 for k, idx in cross_layer_index.items():
                     output_cross_layer[k] = flat_inputs[idx]
                 # -----------------

                 output_per_layers_part = []
                 for i, layer in enumerate(layers_):
                     if 'layer_forward' in self.hooks:
                         layer_ret = self.hooks['layer_forward'](
                             x_, mask, layer_id=layer.layer_id,
                             **kw_args, **output_cross_layer,
                             output_this_layer={}, output_cross_layer={}
                         )
                     else:
                         layer_ret = layer(
                             x_, mask, layer_id=layer.layer_id,
                             **kw_args, **output_cross_layer,
                             output_this_layer={}, output_cross_layer={}
                         )
                     if torch.is_tensor(layer_ret): # only hidden_states
@@ -563,7 +568,7 @@ class BaseTransformer(torch.nn.Module):
                     elif len(layer_ret) == 2: # hidden_states & output_this_layer
                         x_, output_this_layer = layer_ret
                         output_cross_layer = {}
                     elif len(layer_ret) == 3:
                         x_, output_this_layer, output_cross_layer = layer_ret
                     assert isinstance(output_this_layer, dict)
                     assert isinstance(output_cross_layer, dict)
@@ -590,7 +595,7 @@ class BaseTransformer(torch.nn.Module):
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)

             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             output_this_layer = []
@@ -625,20 +630,20 @@ class BaseTransformer(torch.nn.Module):
                 args = [hidden_states, attention_mask]

                 if 'layer_forward' in self.hooks: # customized layer_forward
                     layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
                                                             **kw_args,
                                                             **output_cross_layer,
                                                             output_this_layer={}, output_cross_layer={}
                                                             )
                 else:
                     layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
                                       output_this_layer={}, output_cross_layer={})
                 if torch.is_tensor(layer_ret): # only hidden_states
                     hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
                 elif len(layer_ret) == 2: # hidden_states & output_this_layer
                     hidden_states, output_this_layer = layer_ret
                     output_cross_layer = {}
                 elif len(layer_ret) == 3:
                     hidden_states, output_this_layer, output_cross_layer = layer_ret
                 if output_hidden_states:
...
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
-tokenizer = T5Tokenizer.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-breakpoint()
\ No newline at end of file
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1
\ No newline at end of file
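The script above targets a local checkpoint path and cuda:1. A portable variant of the same Hugging Face sanity check; the "t5-small" checkpoint and CPU device here are assumptions for illustration, not part of the commit.

from transformers import T5ForConditionalGeneration, T5Tokenizer

device = 'cpu'                                           # assumption: no specific GPU required
tokenizer = T5Tokenizer.from_pretrained("t5-small")      # assumption: public Hub checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
model.eval()
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
output = model(input_ids=input_ids, labels=labels)       # passing labels makes HF compute the LM loss
output.loss.backward()                                   # confirm gradients flow end to end
print(float(output.loss))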