From c17f8035011d13770716bc20d50311fc09477904 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 13 Dec 2021 15:16:13 +0800
Subject: [PATCH] Delete unused code

---
 SwissArmyTransformer/model/t5_model.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index e0d2dc6..31f38e0 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -95,7 +95,7 @@ class T5AttentionMixin(BaseMixin):
         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, cross_attention=False):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
@@ -107,10 +107,7 @@ class T5AttentionMixin(BaseMixin):
         )
         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         # shape (query_length, key_length, num_heads)
-        if cross_attention:
-            values = self.cross_relative_attention_bias(relative_position_bucket)
-        else:
-            values = self.relative_attention_bias(relative_position_bucket)
+        values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-- 
GitLab