diff --git a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
index 9b996fb645dfbe485b65aa0f588e89ef65f4449a..88b5c41f20eed71a98bf00d9eaf0e9dcc0110d03 100644
--- a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
+++ b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
@@ -63,13 +63,14 @@ class BeamSearchStrategy:
         
         next_token_scores = next_token_scores.view(batch_size * vocab_size)
 
-        probs = F.softmax(next_token_scores)
-        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) # [2*nb]
+        probs = F.softmax(next_token_scores, dim=0)
+        next_tokens = torch.multinomial(probs, 
+            num_samples=(max(1,len(self.end_tokens))+1) * self.num_beams) # [(len(end_tokens)+1)*nb]
         next_token_scores = next_token_scores[next_tokens]
         next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
         next_tokens = next_tokens[_indices]
 
-        next_indices = next_tokens // vocab_size # batch idx
+        next_indices = torch.div(next_tokens, vocab_size, rounding_mode='trunc')
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
@@ -83,7 +84,7 @@ class BeamSearchStrategy:
             beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
             if int(next_tokens[i]) in self.end_tokens:
                 self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < batch_size:
+            elif len(beam_continue) < self.num_beams:
                 beam_continue.append(beam)
                 mems_contiue.append(mems[:, next_indices[i]])
                 # update caches