Commit 87a3bedf authored by Ming Ding

fix predict many end_tokens bug

parent 69cbe5dd
@@ -63,13 +63,14 @@ class BeamSearchStrategy:
         next_token_scores = next_token_scores.view(batch_size * vocab_size)
-        probs = F.softmax(next_token_scores)
-        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)  # [2*nb]
+        probs = F.softmax(next_token_scores, dim=0)
+        next_tokens = torch.multinomial(probs,
+                                        num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams)  # [2*nb]
         next_token_scores = next_token_scores[next_tokens]
         next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
         next_tokens = next_tokens[_indices]
-        next_indices = next_tokens // vocab_size  # batch idx
+        next_indices = torch.div(next_tokens, vocab_size, rounding_mode='trunc')  # batch idx
         next_tokens = next_tokens % vocab_size
         # select out end beams or continue beams
@@ -83,7 +84,7 @@ class BeamSearchStrategy:
             beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
             if int(next_tokens[i]) in self.end_tokens:
                 self._add_end_beams(next_token_scores[i], beam)
-            elif len(beam_continue) < batch_size:
+            elif len(beam_continue) < self.num_beams:
                 beam_continue.append(beam)
                 mems_contiue.append(mems[:, next_indices[i]])
             # update caches
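
The core of the fix is the sample count: with several possible end tokens, the old draw of 2 * self.num_beams candidates could be filled mostly by end-token hits, leaving fewer than num_beams beams to continue. Because multinomial sampling here is without replacement and each beam row can contribute at most len(self.end_tokens) end-token positions, drawing (len(self.end_tokens) + 1) * self.num_beams candidates always leaves at least num_beams non-end candidates; the second hunk likewise caps continued beams at self.num_beams instead of batch_size. Below is a minimal sketch, not the repository's code, that mirrors the flattened (beam, vocab) sampling from the diff with made-up scores to illustrate the guarantee.

import torch
import torch.nn.functional as F

num_beams = 4
end_tokens = [0, 1, 2]          # toy end-of-sequence ids, not from the repo
vocab_size = 100

# Scores flattened over (beam, vocab), as in the diff.
next_token_scores = torch.randn(num_beams * vocab_size)
probs = F.softmax(next_token_scores, dim=0)   # dim passed explicitly, as in the fix

# Without replacement, at most num_beams * len(end_tokens) sampled positions
# can be end tokens, so this draw leaves at least num_beams non-end candidates.
num_samples = (max(1, len(end_tokens)) + 1) * num_beams
next_tokens = torch.multinomial(probs, num_samples=num_samples)

token_ids = next_tokens % vocab_size
continuing = [int(t) for t in token_ids if int(t) not in end_tokens]
assert len(continuing) >= num_beams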