Commit 7aee8e24 authored by duzx16

Reformat code

parent c761af73
@@ -18,40 +18,44 @@ from SwissArmyTransformer.mpu.transformer import unscaled_init_method
from .base_model import BaseMixin
from .cached_autoregressive_model import CachedAutoregressiveMixin

class PositionEmbeddingMixin(BaseMixin):
    def __init__(self, additional_sequence_length, hidden_size,
                 init_method_std=0.02, reinit_slice=slice(-1024, None)
                 ):
        super(PositionEmbeddingMixin, self).__init__()
        self.reinit_slice = reinit_slice
        # extra position embeddings appended beyond the pretrained sequence length
        self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

    def reinit(self, *pre_mixins):
        # seed the new table from the trailing rows (reinit_slice) of the
        # pretrained position embeddings
        old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
        old_len, hidden_size = old_weights.shape
        assert hidden_size == self.position_embeddings.weight.shape[-1]
        # tile the old rows across the new table; this broadcast copy requires
        # additional_sequence_length to be a multiple of old_len
        self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
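
# --- illustration only (not part of this commit) ---------------------------
# A minimal, self-contained sketch of what PositionEmbeddingMixin.reinit()
# does with the default reinit_slice: the new table is viewed as
# (-1, old_len, hidden_size) and the old rows are broadcast-copied into every
# chunk, so additional_sequence_length must be a multiple of old_len. All
# sizes below are made up for the demo.
def _demo_position_reinit():
    import torch
    old_table = torch.randn(2048, 8)    # stand-in for pretrained embeddings
    old_weights = old_table[-1024:]     # reinit_slice = slice(-1024, None)
    new_table = torch.empty(2048, 8)    # additional_sequence_length = 2048
    new_table.view(-1, 1024, 8).copy_(old_weights)
    assert torch.equal(new_table[:1024], old_weights)
    assert torch.equal(new_table[1024:], old_weights)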

class AttentionMixin(BaseMixin):
    def __init__(self, num_layers,
                 hidden_size,
                 init_method=unscaled_init_method(0.02),
                 output_layer_init_method=unscaled_init_method(0.02)
                 ):
        super(AttentionMixin, self).__init__()
        self.num_layers = num_layers  # replace attention in the LAST n layers
        # per-layer fused query/key/value projections (column-parallel)
        self.query_key_value = torch.nn.ModuleList(
            [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
                                  gather_output=False, init_method=init_method)
             for layer_id in range(num_layers)
             ])
        # per-layer attention output projections (row-parallel)
        self.dense = torch.nn.ModuleList(
            [RowParallelLinear(hidden_size,
                               hidden_size,
                               input_is_parallel=True,
                               init_method=output_layer_init_method)
             for layer_id in range(num_layers)
             ])

    def reinit(self, *pre_mixins):
        # the new parameters apply to the last num_layers transformer layers
        start_layer = len(self.transformer.layers) - self.num_layers
        assert start_layer >= 0
...
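
# --- illustration only (not part of this commit) ---------------------------
# A minimal, single-GPU sketch of the fused QKV projection that
# ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3) provides in
# AttentionMixin above, written with a plain nn.Linear instead of the
# model-parallel layers; shapes and sizes are made up for the demo.
def _demo_fused_qkv_attention():
    import torch
    hidden_size = 8
    qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)  # one matmul for q, k, v
    x = torch.randn(2, 5, hidden_size)                   # (batch, seq, hidden)
    q, k, v = qkv(x).chunk(3, dim=-1)                    # split the fused output
    scores = q @ k.transpose(-1, -2) / hidden_size ** 0.5
    return torch.softmax(scores, dim=-1) @ v             # (batch, seq, hidden)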