diff --git a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py index 9b996fb645dfbe485b65aa0f588e89ef65f4449a..88b5c41f20eed71a98bf00d9eaf0e9dcc0110d03 100644 --- a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py +++ b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py @@ -63,13 +63,14 @@ class BeamSearchStrategy: next_token_scores = next_token_scores.view(batch_size * vocab_size) - probs = F.softmax(next_token_scores) - next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) # [2*nb] + probs = F.softmax(next_token_scores, dim=0) + next_tokens = torch.multinomial(probs, + num_samples=(max(1,len(self.end_tokens))+1) * self.num_beams) # [2*nb] next_token_scores = next_token_scores[next_tokens] next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0) next_tokens = next_tokens[_indices] - next_indices = next_tokens // vocab_size # batch idx + next_indices = torch.div(next_tokens, vocab_size, rounding_mode='trunc') next_tokens = next_tokens % vocab_size # select out end beams or continue beams @@ -83,7 +84,7 @@ class BeamSearchStrategy: beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1])) if int(next_tokens[i]) in self.end_tokens: self._add_end_beams(next_token_scores[i], beam) - elif len(beam_continue) < batch_size: + elif len(beam_continue) < self.num_beams: beam_continue.append(beam) mems_contiue.append(mems[:, next_indices[i]]) # update caches