rau.generation

Given an initial state with one or more batch elements, generate a sequence for each element using beam search.

This includes length normalization. That is, at each timestep of generation, when selecting the top beam_size hypotheses for the next beam, we divide the log probability of each hypothesis by the number of symbols in the hypothesis so far (including EOS), and we keep the beam_size hypotheses with the highest normalized scores.
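The selection rule described above can be sketched as follows. This is a minimal, hypothetical illustration in plain Python, not the library's implementation: it assumes each hypothesis is represented as a (symbols, log_prob) pair, and the helper name `select_top_hypotheses` is invented for this example.

```python
import heapq
from typing import List, Tuple

# Hypothetical representation: (symbols generated so far, total log probability).
Hypothesis = Tuple[List[int], float]


def select_top_hypotheses(hypotheses: List[Hypothesis], beam_size: int) -> List[Hypothesis]:
    """Keep the beam_size hypotheses with the highest length-normalized score.

    Every hypothesis must contain at least one symbol (EOS counts as a symbol),
    so the division below is always well defined.
    """
    def normalized_score(hypothesis: Hypothesis) -> float:
        symbols, log_prob = hypothesis
        return log_prob / len(symbols)

    # heapq.nlargest returns the selected items sorted from best to worst.
    return heapq.nlargest(beam_size, hypotheses, key=normalized_score)


hypotheses = [([3], -1.0), ([1, 2], -1.8), ([1, 2, 3, 4], -2.0)]
print(select_top_hypotheses(hypotheses, beam_size=2))
# → [([1, 2, 3, 4], -2.0), ([1, 2], -1.8)]
```

Without normalization, the one-symbol hypothesis (log probability -1.0) would rank first; after dividing by length, the scores become -1.0, -0.9, and -0.5, so the longer hypotheses are kept. This is the bias against short outputs that length normalization is meant to counteract.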

Parameters:
  • initial_state (State) – The initial state from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements.

  • beam_size (int) – The maximum number of elements allowed on the beam.

  • eos_symbol (int) – Identifier of the designated end-of-sequence (EOS) symbol, which indicates that the model should stop generating symbols for a sequence.

  • max_length (int) – A hard upper limit on the number of symbols in the generated sequences. If the limit is reached before decoding finishes, the output sequence is read out from the highest-scoring beam element at the last timestep.

  • device (device) – The device on which the model is evaluated.

Return type:

list[list[int]]

Returns:

A list of decoded sequences, one per batch element in the initial state.
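To show how beam_size, eos_symbol, and max_length interact, here is a self-contained sketch of length-normalized beam search in plain Python. It is an illustration under stated assumptions, not this function's actual code: the model is replaced by a hypothetical callback `next_log_probs(context, prefix)` that returns a dict mapping each next symbol to its log probability, each batch element is represented by a hypothetical `context` object instead of a `State`, finished hypotheses are assumed to stay on the beam and compete with unfinished ones, and EOS is assumed to be stripped from the returned sequences.

```python
import heapq
import math
from typing import Callable, Dict, List, Sequence, Tuple

# (symbols generated so far, total log probability)
Hypothesis = Tuple[List[int], float]


def beam_search_sketch(
    contexts: Sequence[object],  # hypothetical stand-in for the batch elements
    next_log_probs: Callable[[object, List[int]], Dict[int, float]],  # hypothetical model
    beam_size: int,
    eos_symbol: int,
    max_length: int,
) -> List[List[int]]:
    results = []
    for context in contexts:  # decode each batch element independently
        beam: List[Hypothesis] = [([], 0.0)]
        for _ in range(max_length):
            candidates: List[Hypothesis] = []
            for symbols, log_prob in beam:
                if symbols and symbols[-1] == eos_symbol:
                    # Assumption: finished hypotheses compete unchanged with new ones.
                    candidates.append((symbols, log_prob))
                    continue
                for symbol, symbol_log_prob in next_log_probs(context, symbols).items():
                    candidates.append((symbols + [symbol], log_prob + symbol_log_prob))
            # Length normalization: divide by the number of symbols, including EOS.
            beam = heapq.nlargest(beam_size, candidates, key=lambda h: h[1] / len(h[0]))
            if all(symbols[-1] == eos_symbol for symbols, _ in beam):
                break  # every hypothesis on the beam has generated EOS
        # beam[0] is the highest-scoring element; if max_length was reached,
        # it may not end with EOS.
        best_symbols = beam[0][0]
        if best_symbols and best_symbols[-1] == eos_symbol:
            best_symbols = best_symbols[:-1]  # assumption: EOS is not returned
        results.append(best_symbols)
    return results


EOS = 0


def toy_model(context: int, prefix: List[int]) -> Dict[int, float]:
    # Hypothetical model: `context` is the preferred symbol (1 or 2) for this
    # batch element; after two symbols, EOS becomes very likely.
    if len(prefix) >= 2:
        return {EOS: math.log(0.9), context: math.log(0.1)}
    return {context: math.log(0.6), EOS: math.log(0.3), 3 - context: math.log(0.1)}


print(beam_search_sketch([1, 2], toy_model, beam_size=2, eos_symbol=EOS, max_length=5))
# → [[1, 1], [2, 2]]
```

With max_length=1 in the same call, decoding stops before any beam element generates EOS, so the sketch returns the highest-scoring unfinished hypothesis for each batch element, [[1], [2]]. Decoding batch elements one at a time keeps the example short; a practical implementation would typically batch all beams into tensor operations on the given device.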