From c17f8035011d13770716bc20d50311fc09477904 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 13 Dec 2021 15:16:13 +0800
Subject: [PATCH] Delete unused code

---
 SwissArmyTransformer/model/t5_model.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index e0d2dc6..31f38e0 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -95,7 +95,7 @@ class T5AttentionMixin(BaseMixin):
         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, cross_attention=False):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
@@ -107,10 +107,7 @@ class T5AttentionMixin(BaseMixin):
         )
         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         # shape (query_length, key_length, num_heads)
-        if cross_attention:
-            values = self.cross_relative_attention_bias(relative_position_bucket)
-        else:
-            values = self.relative_attention_bias(relative_position_bucket)
+        values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
--
GitLab
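
For reference, the code path that survives this patch is the standard T5 relative-position bias: bucket the signed query/key offsets, look the buckets up in a single learned embedding table, and reshape to (1, num_heads, query_length, key_length). Below is a minimal, self-contained sketch of that computation mirroring the simplified compute_bias. It assumes relative_attention_bias is an nn.Embedding(num_buckets, num_heads) and re-implements the bucketing helper in the usual T5 form; the free-standing names (relative_position_bucket, compute_bias, bias_embedding) are illustrative only, not part of the SwissArmyTransformer API.

import math

import torch
import torch.nn as nn


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """Map signed relative positions to bucket indices, following the usual T5 scheme."""
    relative_buckets = 0
    if bidirectional:
        # Half of the buckets are reserved for positive (future) offsets.
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # Small offsets get one bucket each; larger offsets share logarithmically spaced buckets.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


def compute_bias(bias_embedding, query_length, key_length, num_buckets=32):
    """Single bias table for all attention, mirroring the simplified compute_bias above."""
    context_position = torch.arange(query_length, dtype=torch.long)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
    relative_position = memory_position - context_position          # (query_length, key_length)
    buckets = relative_position_bucket(relative_position, num_buckets=num_buckets)
    buckets = buckets.to(bias_embedding.weight.device)
    values = bias_embedding(buckets)                                 # (query_length, key_length, num_heads)
    return values.permute([2, 0, 1]).unsqueeze(0)                    # (1, num_heads, query_length, key_length)


if __name__ == "__main__":
    num_buckets, num_heads = 32, 8
    bias_embedding = nn.Embedding(num_buckets, num_heads)            # stand-in for relative_attention_bias
    bias = compute_bias(bias_embedding, query_length=5, key_length=5, num_buckets=num_buckets)
    print(bias.shape)                                                # torch.Size([1, 8, 5, 5])

With the cross_attention flag and the separate cross_relative_attention_bias table gone, callers obtain the bias for any attention pattern from this one table; the resulting tensor is simply added to the attention scores before the softmax.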