From 7aee8e24c61a435416c163aede7a408f52a9efb8 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Wed, 1 Dec 2021 19:13:43 +0800
Subject: [PATCH] Reformat code

---
 SwissArmyTransformer/model/mixins.py | 38 +++++++++++++++-------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/SwissArmyTransformer/model/mixins.py b/SwissArmyTransformer/model/mixins.py
index 2a76b80..2c6b099 100644
--- a/SwissArmyTransformer/model/mixins.py
+++ b/SwissArmyTransformer/model/mixins.py
@@ -18,40 +18,44 @@ from SwissArmyTransformer.mpu.transformer import unscaled_init_method
 from .base_model import BaseMixin
 from .cached_autoregressive_model import CachedAutoregressiveMixin
 
+
 class PositionEmbeddingMixin(BaseMixin):
-    def __init__(self, additional_sequence_length, hidden_size, 
-                init_method_std=0.02, reinit_slice=slice(-1024, None)
-        ):
+    def __init__(self, additional_sequence_length, hidden_size,
+                 init_method_std=0.02, reinit_slice=slice(-1024, None)
+                 ):
         super(PositionEmbeddingMixin, self).__init__()
         self.reinit_slice = reinit_slice
         self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
+
     def reinit(self, *pre_mixins):
         old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
         old_len, hidden_size = old_weights.shape
         assert hidden_size == self.position_embeddings.weight.shape[-1]
         self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
 
+
 class AttentionMixin(BaseMixin):
     def __init__(self, num_layers,
-                hidden_size, 
-                init_method=unscaled_init_method(0.02),
-                output_layer_init_method=unscaled_init_method(0.02)
-        ):
+                 hidden_size,
+                 init_method=unscaled_init_method(0.02),
+                 output_layer_init_method=unscaled_init_method(0.02)
+                 ):
         super(AttentionMixin, self).__init__()
-        self.num_layers = num_layers # replace attention in the LAST n layers
+        self.num_layers = num_layers  # replace attention in the LAST n layers
         self.query_key_value = torch.nn.ModuleList(
-            [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
-                gather_output=False,init_method=init_method)
-                for layer_id in range(num_layers)
-            ])
+            [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
+                                  gather_output=False, init_method=init_method)
+             for layer_id in range(num_layers)
+             ])
         self.dense = torch.nn.ModuleList(
             [RowParallelLinear(hidden_size,
-                hidden_size,
-                input_is_parallel=True,
-                init_method=output_layer_init_method)
-                for layer_id in range(num_layers)
-            ])
+                               hidden_size,
+                               input_is_parallel=True,
+                               init_method=output_layer_init_method)
+             for layer_id in range(num_layers)
+             ])
+
     def reinit(self, *pre_mixins):
         start_layer = len(self.transformer.layers) - self.num_layers
         assert start_layer >= 0
-- 
GitLab