Skip to content
Snippets Groups Projects
Commit 00442c04 authored by duzx16
Browse files

Add requires_grad at the beginning

parent be349db4
No related branches found
No related tags found
No related merge requests found
@@ -527,6 +527,8 @@ class BaseTransformer(torch.nn.Module):
 l, num_layers = 0, len(self.layers)
 chunk_length = self.checkpoint_num_layers
+if self.training:
+    hidden_states.requires_grad_(True)
 while l < num_layers:
     if branch_input is not None:
         args = [hidden_states, attention_mask, branch_input]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment