Commit 0d2385f3 authored by duzx16

Merge branch 'enc-dec' into finer-attn-hooks

# Conflicts:
#	SwissArmyTransformer/generation/autoregressive_sampling.py
#	SwissArmyTransformer/model/cached_autoregressive_model.py
#	SwissArmyTransformer/model/t5_model.py
#	examples/t5/test_t5.py
parents 7c573bbd 308e87b0
SwissArmyTransformer/model/t5_model.py
@@ -6,7 +6,8 @@ from .encoder_decoder_model import EncoderDecoderModel
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
-from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim
+from SwissArmyTransformer.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
+from SwissArmyTransformer.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding


 class T5PositionEmbeddingMixin(BaseMixin):
@@ -142,7 +143,7 @@ class T5AttentionMixin(BaseMixin):
         kw_args['output_cross_layer']['position_bias'] = position_bias
         return output

     def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
                                 **kw_args):
@@ -173,17 +174,63 @@ class T5AttentionMixin(BaseMixin):
 class T5DecoderFinalMixin(BaseMixin):
-    def __init__(self, hidden_size):
+    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size

     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-        logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
-        logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        if self.tie_word_embeddings:
+            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
+            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
+        else:
+            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
         return logits_parallel
+
+
+def t5_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5GatedGeluMLPMixin(BaseMixin):
+    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
+        super().__init__()
+        self.hidden_size = hidden_size
+        if inner_hidden_size is None:
+            inner_hidden_size = 4 * hidden_size
+        self.inner_hidden_size = inner_hidden_size
+        self.init_method_std = init_method_std
+        self.gated_h_to_4h_list = torch.nn.ModuleList([
+            ColumnParallelLinear(
+                self.hidden_size,
+                self.inner_hidden_size,
+                gather_output=False,
+                init_method=self._init_weights,
+                bias=bias,
+                module=self,
+                name="gated_h_to_4h"
+            )
+            for layer_id in range(num_layers)])
+
+    def _init_weights(self, weight, **kwargs):
+        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))
+
+    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
+        mlp_module = self.transformer.layers[layer_id].mlp
+        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
+        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        output = mlp_module.dense_4h_to_h(hidden_states)
+        if self.training:
+            output = mlp_module.dropout(output)
+        return output
+
+
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
         self.init_method_std = args.init_method_std
@@ -205,10 +252,19 @@ class T5Model(EncoderDecoderModel):
             "t5-position", T5PositionEmbeddingMixin()
         )
         self.decoder.add_mixin(
-            "t5-final", T5DecoderFinalMixin(args.hidden_size)
+            "t5-final",
+            T5DecoderFinalMixin(args.vocab_size, args.hidden_size, tie_word_embeddings=not args.no_share_embeddings)
         )
         del self.decoder.transformer.position_embeddings
+        if args.gated_gelu_mlp:
+            self.encoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )
+            self.decoder.add_mixin(
+                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size, init_method_std=self.init_method_std,
+                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
+            )

     def _init_weights(self, weight, module, name):
         init_method_std = self.init_method_std
@@ -246,6 +302,8 @@ class T5Model(EncoderDecoderModel):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
         parser.add_argument("--init-method-std", type=float, default=0.02)
+        parser.add_argument("--gated-gelu-mlp", action='store_true')
+        parser.add_argument("--no-share-embeddings", action='store_true')

     def encode(self, input_ids, attention_mask=None, **kw_args):
         return super().encode(input_ids, None, attention_mask, **kw_args)
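Note on the output projection above: with tied embeddings, T5DecoderFinalMixin reuses the transformer's input word-embedding matrix and rescales the hidden states by hidden_size ** -0.5 before projecting, as T5 does when embeddings are shared; with --no-share-embeddings it projects with a separate lm_head weight and skips the rescaling. The construction of that untied lm_head is not visible in this hunk (the newly imported VocabParallelEmbedding and unscaled_init_method suggest it is built in __init__, but that is an inference). A minimal single-device sketch of the two paths, with plain torch modules standing in for the model-parallel ones and purely illustrative names:

# Single-device sketch of the two output-projection paths in T5DecoderFinalMixin.
# Assumption: nn.Embedding / nn.Parameter stand in for the model-parallel word
# embeddings and the (not shown above) untied lm_head.
import torch
import torch.nn.functional as F

hidden_size, vocab_size = 16, 100
word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)                # shared input embedding
lm_head_weight = torch.nn.Parameter(torch.randn(vocab_size, hidden_size))    # illustrative untied head

def final_forward(hidden, tie_word_embeddings=True):
    if tie_word_embeddings:
        # shared embeddings: rescale by d_model ** -0.5 before projecting, as in T5
        return F.linear(hidden * hidden_size ** -0.5, word_embeddings.weight)
    # untied: project with a separately initialized output matrix, no rescaling
    return F.linear(hidden, lm_head_weight)

logits = final_forward(torch.randn(2, 4, hidden_size), tie_word_embeddings=False)
print(logits.shape)  # torch.Size([2, 4, 100])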
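The new T5GatedGeluMLPMixin switches the feed-forward block to T5 v1.1's gated-GELU variant: the existing dense_h_to_4h projection feeds the tanh-approximate GELU, the newly added bias-free projection acts as a linear gate, and their elementwise product goes through dense_4h_to_h. A minimal single-device sketch with plain nn.Linear in place of ColumnParallelLinear (class and size names here are illustrative, not part of the repo):

# Single-device sketch of the gated-GELU feed-forward (no model parallelism, no dropout).
import math
import torch
import torch.nn as nn

def t5_gelu(x):
    # tanh approximation of GELU, matching the function added in the diff above
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class GatedGeluMLP(nn.Module):
    def __init__(self, hidden_size, inner_hidden_size=None, bias=False):
        super().__init__()
        inner_hidden_size = inner_hidden_size or 4 * hidden_size
        self.dense_h_to_4h = nn.Linear(hidden_size, inner_hidden_size, bias=bias)   # GELU branch
        self.gated_h_to_4h = nn.Linear(hidden_size, inner_hidden_size, bias=bias)   # linear gate branch
        self.dense_4h_to_h = nn.Linear(inner_hidden_size, hidden_size, bias=bias)   # output projection

    def forward(self, x):
        # GELU(x W_in) * (x W_gate), then project back to hidden_size
        return self.dense_4h_to_h(t5_gelu(self.dense_h_to_4h(x)) * self.gated_h_to_4h(x))

mlp = GatedGeluMLP(hidden_size=1024, inner_hidden_size=2816)
print(mlp(torch.randn(2, 8, 1024)).shape)  # torch.Size([2, 8, 1024])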
SwissArmyTransformer/mpu/transformer.py
@@ -158,7 +158,7 @@ class SelfAttention(torch.nn.Module):
         if self.training:
             output = self.output_dropout(output)
-        return output
+        return output, None


 class CrossAttention(torch.nn.Module):
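SelfAttention.forward now returns a pair instead of a bare tensor, so any direct caller has to unpack two values; the second slot is None here and appears reserved for per-layer outputs from the finer attention hooks (that purpose is an assumption, not shown in this diff). A trivial stand-in illustrating only the new calling convention:

# Stub mirroring the (output, extras) return shape; the real module computes attention.
import torch

def self_attention_forward_stub(hidden_states):
    output = hidden_states    # placeholder for the real attention computation
    return output, None       # second element currently unused

output, extras = self_attention_forward_stub(torch.randn(1, 4, 8))
assert extras is None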
examples/t5/test_t5.py
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
-tokenizer = T5Tokenizer.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-breakpoint()
\ No newline at end of file
+device = 'cuda:1'
+tokenizer = T5Tokenizer.from_pretrained("t5-large")
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-xl-lm-adapt")
+model = model.to(device)
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to(device)
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to(device)
+breakpoint()
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
+a = 1
\ No newline at end of file
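The updated test script now drives the Hugging Face reference model with labels= (which makes T5ForConditionalGeneration return a cross-entropy loss it can backprop through) instead of raw decoder_input_ids. For a quick qualitative check of the same span-corruption prompt, generate() can be used instead; a small sketch using the public t5-large checkpoint (the /dataset/... path above is environment-specific and left as-is):

# Hedged sketch: fill the sentinel spans with generate() instead of computing a loss.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-large")
model = T5ForConditionalGeneration.from_pretrained("t5-large").eval()

input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
with torch.no_grad():
    generated = model.generate(input_ids, max_length=20)
# keep special tokens so the <extra_id_*> sentinels delimiting each filled span stay visible
print(tokenizer.decode(generated[0], skip_special_tokens=False))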