rau.generation
- rau.generation.beam_search(initial_state, beam_size, eos_symbol, max_length, device)
Given an initial state with multiple batch elements, generate a sequence for each element using beam search.
This includes length normalization. That is, for each timestep of generation, when selecting the top beam_size hypotheses for the next beam, we divide the log probability of each hypothesis by the number of symbols in the hypothesis so far (including EOS), and we select the beam_size hypotheses with the highest scores.
- Parameters:
initial_state (State) – The initial state from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements.
beam_size (int) – The maximum number of elements allowed on the beam.
eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.
max_length (int) – A hard upper limit on the number of symbols in the generated sequences. If the limit is reached, decoding will start from the highest-scoring beam element at the last timestep.
device (device) – A device used to evaluate the model.
- Return type:
- Returns:
A list of decoded sequences, one per batch element in the initial state.
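
The length normalization described above can be illustrated with a small standalone sketch. This is not the implementation used by rau.generation.beam_search. The helper names (length_normalized_score, select_next_beam), the candidate list, and the use of symbol 0 as EOS are all hypothetical and chosen only for illustration. Comparing a finished hypothesis against longer ones in the same candidate set is likewise an assumption.

```python
import heapq
import math


def length_normalized_score(log_prob, num_symbols):
    """Total log probability divided by hypothesis length (EOS counts as a symbol)."""
    return log_prob / num_symbols


def select_next_beam(candidates, beam_size):
    """Keep the beam_size candidates with the highest length-normalized scores.

    candidates is a list of (symbols, log_prob) pairs, where symbols is a
    non-empty list of ints and log_prob is the hypothesis's total log probability.
    """
    return heapq.nlargest(
        beam_size,
        candidates,
        key=lambda c: length_normalized_score(c[1], len(c[0])),
    )


# Hypothetical candidates; symbol 0 plays the role of EOS here.
candidates = [
    ([5, 0], math.log(0.30)),     # short hypothesis, ends in EOS
    ([5, 7, 2], math.log(0.20)),  # longer hypothesis, lower raw probability
    ([4, 1, 3], math.log(0.05)),
]

beam = select_next_beam(candidates, beam_size=2)
for symbols, log_prob in beam:
    print(symbols, length_normalized_score(log_prob, len(symbols)))
```

In this example the three-symbol hypothesis [5, 7, 2] ranks above the two-symbol hypothesis [5, 0] even though its raw log probability is lower, because dividing by length offsets the penalty that longer sequences accumulate.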