diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index e0d2dc6d4f9776f49582beead4089ce0a63c42e5..31f38e0d1bc20229e97d4efdec33bb91f1d4c14d 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -95,7 +95,7 @@ class T5AttentionMixin(BaseMixin):
         relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, cross_attention=False):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
@@ -107,10 +107,7 @@ class T5AttentionMixin(BaseMixin):
         )
         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         # shape (query_length, key_length, num_heads)
-        if cross_attention:
-            values = self.cross_relative_attention_bias(relative_position_bucket)
-        else:
-            values = self.relative_attention_bias(relative_position_bucket)
+        values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1]).unsqueeze(0)
         # shape (1, num_heads, query_length, key_length)
         return values
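
For context, a minimal, self-contained PyTorch sketch of how a T5-style relative position bias of this shape is typically built and consumed. This is not the SwissArmyTransformer implementation: the bucketing is simplified to a plain clamp instead of T5's log-spaced buckets, and the names (num_heads, num_buckets, relative_attention_bias, compute_bias) are placeholders mirroring the diff above.

    # Sketch only: simplified stand-in for the relative position bias in the diff above.
    import torch
    import torch.nn as nn

    num_heads, num_buckets = 8, 32

    # Hypothetical stand-in for self.relative_attention_bias:
    # one embedding row per bucket, one column per attention head.
    relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    def compute_bias(query_length, key_length):
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position      # (query_length, key_length)
        # Simplified bucketing: clamp signed distances into [0, num_buckets),
        # rather than T5's bidirectional log-spaced buckets.
        buckets = torch.clamp(relative_position + num_buckets // 2, 0, num_buckets - 1)
        values = relative_attention_bias(buckets)                   # (query_length, key_length, num_heads)
        return values.permute([2, 0, 1]).unsqueeze(0)                # (1, num_heads, query_length, key_length)

    # Usage: the bias is broadcast-added to raw attention scores before softmax.
    q_len = k_len = 5
    scores = torch.randn(2, num_heads, q_len, k_len)                 # (batch, heads, q_len, k_len)
    scores = scores + compute_bias(q_len, k_len)
    probs = scores.softmax(dim=-1)

The removal of the cross_attention branch is consistent with standard T5, where only self-attention carries a learned relative position bias; cross-attention scores are left without one.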