diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index d9d3181abbdd5ab288d1e3c81f7d8be31fd59969..3b4ae73cc6609178af67564198be9b15038ab9d2 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -158,7 +158,7 @@ class SelfAttention(torch.nn.Module): if self.training: output = self.output_dropout(output) - return output, None + return output class CrossAttention(torch.nn.Module):