rau.generation

This module includes algorithms for generating strings of symbols from autoregressive language models or decoders, which are encapsulated in State objects.

rau.generation.sample(initial_state, eos_symbol, max_length, generator=None)

Given a state of an autoregressive language model or decoder containing any number of batch elements, generate a sequence for each element using ancestral sampling.

Parameters:
  • initial_state (State) – A state of an autoregressive decoder or language model from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements. Note that this does not actually need to be the initial state of a decoder; decoding can start from any state.

  • eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.

  • max_length (int) – A hard upper limit on the number of symbols in the generated sequences.

  • generator (Generator | None) – Optional random number generator; passing a seeded generator makes sampling reproducible.

Return type:

list[list[int]]

Returns:

A list of generated sequences, one per batch element in the initial state.
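To illustrate what ancestral sampling does, here is a minimal pure-Python sketch. It does not use rau's State API. Instead it assumes a hypothetical toy model, `next_probs(prefix)`, that returns a probability distribution over symbol IDs, and it represents each batch element as a prefix of symbol IDs. The handling of EOS (excluded from the output) and of `max_length` (at most that many symbols are drawn) are assumptions of this sketch.

```python
import random
from typing import Callable, List, Optional, Sequence

# Hypothetical toy model (not part of rau): maps a prefix of symbol IDs to a
# probability distribution over the vocabulary.
NextProbs = Callable[[Sequence[int]], Sequence[float]]


def sample(
    initial_prefixes: Sequence[Sequence[int]],
    next_probs: NextProbs,
    eos_symbol: int,
    max_length: int,
    generator: Optional[random.Random] = None,
) -> List[List[int]]:
    rng = generator if generator is not None else random.Random()
    results = []
    for prefix in initial_prefixes:
        output: List[int] = []
        for _ in range(max_length):
            probs = next_probs(list(prefix) + output)
            # Ancestral sampling: draw the next symbol from the model's
            # distribution conditioned on everything generated so far.
            symbol = rng.choices(range(len(probs)), weights=probs, k=1)[0]
            if symbol == eos_symbol:
                break  # in this sketch, EOS ends the sequence and is not returned
            output.append(symbol)
        results.append(output)
    return results
```

Passing `random.Random(seed)` as `generator` plays the same role as the `generator` parameter above: two calls with identically seeded generators produce identical samples.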

rau.generation.decode_greedily(initial_state, eos_symbol, max_length)

Given a state of an autoregressive language model or decoder containing any number of batch elements, generate a sequence for each element using greedy decoding.

Parameters:
  • initial_state (State) – A state of an autoregressive decoder or language model from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements. Note that this does not actually need to be the initial state of a decoder; decoding can start from any state.

  • eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.

  • max_length (int) – A hard upper limit on the number of symbols in the generated sequences.

Return type:

list[list[int]]

Returns:

A list of generated sequences, one per batch element in the initial state.
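Greedy decoding differs from sampling only in how the next symbol is chosen: the most probable symbol is always taken. The sketch below uses the same hypothetical toy model interface, `next_probs(prefix)`, in place of rau's State. The EOS handling and the tie-breaking rule (lowest symbol ID wins) are assumptions of this sketch, not necessarily rau's behavior.

```python
from typing import Callable, List, Sequence

# Hypothetical toy model (not part of rau): maps a prefix of symbol IDs to a
# probability distribution over the vocabulary.
NextProbs = Callable[[Sequence[int]], Sequence[float]]


def decode_greedily(
    initial_prefixes: Sequence[Sequence[int]],
    next_probs: NextProbs,
    eos_symbol: int,
    max_length: int,
) -> List[List[int]]:
    results = []
    for prefix in initial_prefixes:
        output: List[int] = []
        for _ in range(max_length):
            probs = next_probs(list(prefix) + output)
            # Greedy choice: the single most probable symbol. max() returns the
            # first maximal index, so ties go to the lowest symbol ID.
            symbol = max(range(len(probs)), key=lambda i: probs[i])
            if symbol == eos_symbol:
                break  # EOS is not included in the returned sequence
            output.append(symbol)
        results.append(output)
    return results
```

Because no randomness is involved, this procedure takes no `generator` argument, which matches the signature above.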

Given a state of an autoregressive language model or decoder containing any number of batch elements, generate a sequence for each element using beam search.

This includes length normalization. That is, at each timestep of generation, when selecting the top beam_size hypotheses for the next beam, we divide the total (not length-normalized) log probability of each hypothesis by the number of symbols in the hypothesis so far (including EOS), and we keep the beam_size hypotheses with the highest resulting scores.
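The selection rule can be shown on a small worked example. The function names here are illustrative, and symbol 0 stands in for EOS. Each hypothesis's total log probability is divided by its length (EOS included), and the beam_size highest scores are kept.

```python
import heapq
from typing import List, Tuple

# A hypothesis is (symbols so far including any EOS, total log probability).
Hypothesis = Tuple[List[int], float]


def length_normalized_score(hyp: Hypothesis) -> float:
    symbols, log_prob = hyp
    # Divide the total log probability by the number of symbols (EOS included).
    return log_prob / len(symbols)


def select_beam(candidates: List[Hypothesis], beam_size: int) -> List[Hypothesis]:
    # Keep the beam_size candidates with the highest length-normalized scores.
    return heapq.nlargest(beam_size, candidates, key=length_normalized_score)


# Symbol 0 plays the role of EOS in this example.
candidates = [
    ([1, 0], -1.0),        # score -1.0 / 2 = -0.5
    ([1, 2, 2, 0], -1.6),  # score -1.6 / 4 = -0.4
    ([2], -0.3),           # score -0.3 / 1 = -0.3
]
print(select_beam(candidates, 2))
# [([2], -0.3), ([1, 2, 2, 0], -1.6)]
```

Ranking by raw log probability alone would keep `[2]` and `[1, 0]`. With length normalization, the longer hypothesis `[1, 2, 2, 0]` survives instead, because its per-symbol log probability is higher.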

Parameters:
  • initial_state (State) – A state of an autoregressive decoder or language model from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements. Note that this does not actually need to be the initial state of a decoder; decoding can start from any state.

  • beam_size (int) – The maximum number of elements allowed on the beam.

  • eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.

  • max_length (int) – A hard upper limit on the number of symbols in the generated sequences. If the limit is reached, the output sequence is reconstructed starting from the highest-scoring beam element at the last timestep.

  • device (device) – The device where intermediate data (log probabilities, backpointers) will be stored.

Return type:

list[list[int]]

Returns:

A list of decoded sequences, one per batch element in the initial state.
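To make the procedure concrete, here is a minimal pure-Python sketch of beam search with length normalization. It does not reproduce rau's State API or its use of `device` and backpointer tensors. Instead it stores full hypotheses and uses a hypothetical toy model, `next_probs(prefix)`, that returns a probability distribution over symbol IDs. Several behaviors are assumptions of this sketch, not confirmed details of rau: finished hypotheses stay on the beam and compete with new extensions, and the trailing EOS is stripped from the output.

```python
import heapq
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical toy model (not part of rau): maps a prefix of symbol IDs to a
# probability distribution over the vocabulary.
NextProbs = Callable[[Sequence[int]], Sequence[float]]
# A hypothesis is (generated symbols including any final EOS, total log probability).
Hypothesis = Tuple[List[int], float]


def beam_search(
    initial_prefixes: Sequence[Sequence[int]],
    next_probs: NextProbs,
    beam_size: int,
    eos_symbol: int,
    max_length: int,
) -> List[List[int]]:
    def is_finished(symbols: List[int]) -> bool:
        return len(symbols) > 0 and symbols[-1] == eos_symbol

    def score(hyp: Hypothesis) -> float:
        symbols, log_prob = hyp
        # Length normalization: divide by the number of symbols, EOS included.
        return log_prob / len(symbols) if symbols else 0.0

    results = []
    for prefix in initial_prefixes:
        beam: List[Hypothesis] = [([], 0.0)]
        for _ in range(max_length):
            if all(is_finished(hyp[0]) for hyp in beam):
                break
            candidates: List[Hypothesis] = []
            for symbols, log_prob in beam:
                if is_finished(symbols):
                    # Assumption: finished hypotheses stay on the beam and
                    # compete with new extensions using their existing score.
                    candidates.append((symbols, log_prob))
                    continue
                probs = next_probs(list(prefix) + symbols)
                for symbol, p in enumerate(probs):
                    if p > 0.0:  # skip impossible symbols (log 0 is undefined)
                        candidates.append((symbols + [symbol], log_prob + math.log(p)))
            beam = heapq.nlargest(beam_size, candidates, key=score)
        # If max_length was reached, this picks the best element of the final beam.
        best_symbols, _ = max(beam, key=score)
        # Strip the trailing EOS, if any, to match the other sketches.
        results.append(best_symbols[:-1] if is_finished(best_symbols) else best_symbols)
    return results
```

Storing whole hypotheses keeps the sketch short. An implementation that keeps backpointers per timestep, as the `device` parameter description suggests, avoids copying symbol lists at every step.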