Commit 51e1a513 authored by duzx16

Add non_conflict for T5 attention

fix position bias for memory in T5
parent c17f8035
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
@@ -111,6 +112,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
+    @non_conflict
     def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
                      cross_attention=False, **kw_args):
         log_attention_weights = None
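The @non_conflict decorator lets this attention_fn hook coexist with other mixins that also override attention: the hook receives the previously active implementation as old_impl and delegates to it rather than replacing it outright. The sketch below only illustrates that chaining pattern; non_conflict, standard_attention, t5_bias_hook, and chain here are toy stand-ins assumed for illustration, not SwissArmyTransformer's actual implementation.

import torch
from functools import partial

def non_conflict(fn):
    # Toy marker standing in for SwissArmyTransformer's non_conflict decorator:
    # it only tags the hook so the chaining logic below knows it may wrap,
    # rather than replace, an earlier implementation.
    fn.non_conflict = True
    return fn

def standard_attention(q, k, v, mask, dropout_fn, log_attention_weights=None, **kw_args):
    # Minimal scaled-dot-product attention playing the role of
    # mpu.transformer.standard_attention.
    scores = q @ k.transpose(-1, -2) / q.size(-1) ** 0.5
    if log_attention_weights is not None:
        scores = scores + log_attention_weights
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return dropout_fn(torch.softmax(scores, dim=-1)) @ v

@non_conflict
def t5_bias_hook(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
    # Adds a (here all-zero) relative position bias and then defers to whatever
    # implementation came before it, mirroring how the mixin's attention_fn
    # ends by calling old_impl.
    bias = torch.zeros(1, q.size(1), q.size(2), k.size(2))
    return old_impl(q, k, v, mask, dropout_fn, log_attention_weights=bias, **kw_args)

def chain(hooks, base=standard_attention):
    # Assumed composition rule: each non_conflict hook wraps the implementation
    # built so far; an undecorated hook would simply replace it.
    impl = base
    for hook in hooks:
        impl = partial(hook, old_impl=impl) if getattr(hook, 'non_conflict', False) else hook
    return impl

q = k = v = torch.randn(2, 4, 5, 8)        # (batch, num_heads, seq_len, hidden_per_head)
mask = torch.ones(2, 1, 5, 5)
attn = chain([t5_bias_hook])
print(attn(q, k, v, mask, dropout_fn=torch.nn.Identity()).shape)  # torch.Size([2, 4, 5, 8])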
@@ -118,7 +120,8 @@ class T5AttentionMixin(BaseMixin):
         if position_bias is None:
             seq_length = q.size(2)
             key_length = k.size(2)
-            position_bias = self.compute_bias(seq_length, key_length)
+            position_bias = self.compute_bias(key_length, key_length)
+            position_bias = position_bias[:, :, -seq_length:, :]
         kw_args['output_cross_layer']['position_bias'] = position_bias
         log_attention_weights = position_bias
         return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
......
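The position-bias fix matters when attention_fn runs during incremental decoding with cached memory: q then covers only the new tokens (seq_length) while k spans memory plus new tokens (key_length), so the relative bias has to be computed over the full key_length-by-key_length grid and sliced to the last seq_length query rows. The sketch below shows why the slice is needed; toy_relative_bias is a hypothetical stand-in for T5AttentionMixin.compute_bias, not the library's actual implementation.

import torch

def toy_relative_bias(query_length, key_length, num_heads=2, max_distance=8):
    # Hypothetical stand-in for T5AttentionMixin.compute_bias: the bias depends
    # only on the clipped relative distance key_position - query_position.
    context_position = torch.arange(query_length)[:, None]
    memory_position = torch.arange(key_length)[None, :]
    relative_position = (memory_position - context_position).clamp(-max_distance, max_distance)
    # A fixed per-head lookup table plays the role of the learned bias embeddings.
    table = torch.linspace(-1.0, 1.0, 2 * max_distance + 1).repeat(num_heads, 1)
    bias = table[:, relative_position + max_distance]  # (num_heads, query_length, key_length)
    return bias.unsqueeze(0)                           # (1, num_heads, query_length, key_length)

# Incremental decoding: 3 new query tokens attend over 10 keys (7 cached in memory + 3 new).
seq_length, key_length = 3, 10

# Old code: bias computed as if the queries sat at positions 0..2, ignoring the memory offset.
old_bias = toy_relative_bias(seq_length, key_length)

# Fixed code: compute the full key_length x key_length bias, then keep only the rows
# belonging to the last seq_length query positions (7..9), as the commit does.
new_bias = toy_relative_bias(key_length, key_length)[:, :, -seq_length:, :]

print(old_bias.shape == new_bias.shape)   # True: same shape either way
print(torch.equal(old_bias, new_bias))    # False: the old version mis-places the queries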