rau.generation¶
This module includes algorithms for generating strings of symbols from
autoregressive language models or decoders that are encapsulated in
States.
- rau.generation.sample(initial_state, eos_symbol, num_samples=1, max_length=None, generator=None)¶
Given a state of an autoregressive language model or decoder containing any number of batch elements, randomly generate num_samples sequences for each element using ancestral sampling. Sampling is parallelized across batch elements and samples.
- Parameters:
  - initial_state (State) – A state of an autoregressive decoder or language model from which decoding starts, containing any number of batch elements. A separate list of num_samples sequences will be decoded for each of the initial batch elements. Note that this does not actually need to be the initial state of a decoder; decoding can start from any state.
  - eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.
  - num_samples (int) – Number of samples per batch element.
  - max_length (int | None) – A hard upper limit on the number of symbols in the generated sequences.
  - generator (Generator | None) – Optional random number generator to make sampling deterministic.
- Return type:
- Returns:
Randomly generated sequences for each batch element. This will be a list of lists of sequences, where each batch element has a list containing num_samples sequences.
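As a point of reference, ancestral sampling can be sketched in a few lines of plain Python. This is not rau's implementation: the `next_distribution` callback and `toy_next_distribution` below are hypothetical stand-ins for a State's next-symbol probabilities, and the sketch handles a single batch element with no parallelization.

```python
import random

def ancestral_sample(next_distribution, eos_symbol, num_samples=1,
                     max_length=None, rng=None):
    """Sample sequences symbol by symbol (ancestral sampling).

    next_distribution maps a prefix (a list of ints) to a list of
    probabilities over the vocabulary.
    """
    rng = rng if rng is not None else random.Random()
    sequences = []
    for _ in range(num_samples):
        prefix = []
        while max_length is None or len(prefix) < max_length:
            probs = next_distribution(prefix)
            # Draw the next symbol from the model's distribution.
            symbol = rng.choices(range(len(probs)), weights=probs)[0]
            if symbol == eos_symbol:
                break  # EOS ends the sequence and is not emitted.
            prefix.append(symbol)
        sequences.append(prefix)
    return sequences

# Hypothetical toy distribution over a 3-symbol vocabulary; symbol 2 is EOS.
def toy_next_distribution(prefix):
    return [0.4, 0.3, 0.3]

samples = ancestral_sample(toy_next_distribution, eos_symbol=2,
                           num_samples=5, max_length=10,
                           rng=random.Random(0))
```

Here a seeded `random.Random` plays the role of the `generator` argument: with the same seed, the same sequences are drawn.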
- rau.generation.decode_greedily(initial_state, eos_symbol, max_length=None)¶
Given a state of an autoregressive language model or decoder containing any number of batch elements, generate a sequence for each element using greedy decoding. Decoding is parallelized across batch elements.
- Parameters:
  - initial_state (State) – A state of an autoregressive decoder or language model from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements. Note that this does not actually need to be the initial state of a decoder; decoding can start from any state.
  - eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.
  - max_length (int | None) – A hard upper limit on the number of symbols in the generated sequences.
- Return type:
- Returns:
A list of generated sequences, one per batch element in the initial state.
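Greedy decoding simply takes the highest-probability symbol at every step. The sketch below is illustrative rather than rau's actual code; `next_distribution` and `toy_next_distribution` are hypothetical, and only a single batch element is handled.

```python
def decode_greedily(next_distribution, eos_symbol, max_length=None):
    """Pick the highest-probability next symbol at every step until EOS."""
    sequence = []
    while max_length is None or len(sequence) < max_length:
        probs = next_distribution(sequence)
        # Argmax over the vocabulary.
        symbol = max(range(len(probs)), key=lambda s: probs[s])
        if symbol == eos_symbol:
            break
        sequence.append(symbol)
    return sequence

# Hypothetical toy distribution: symbol 1 dominates for the first three
# steps, after which EOS (symbol 2) dominates.
def toy_next_distribution(prefix):
    return [0.1, 0.6, 0.3] if len(prefix) < 3 else [0.1, 0.2, 0.7]

result = decode_greedily(toy_next_distribution, eos_symbol=2)  # → [1, 1, 1]
```

Unlike sampling, greedy decoding is deterministic, so it needs no random generator.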
- rau.generation.beam_search(initial_state, beam_size, eos_symbol, max_length, device)¶
Given a state of an autoregressive language model or decoder containing any number of batch elements, generate a sequence for each element using beam search.
This includes length normalization. That is, for each timestep of generation, when selecting the top beam_size hypotheses for the next beam, we divide the (unnormalized) log probability of each hypothesis by the number of symbols in the hypothesis so far (including EOS), and we select the beam_size hypotheses with the highest scores.
- Parameters:
  - initial_state (State) – A state of an autoregressive decoder or language model from which decoding starts, containing any number of batch elements. A separate sequence will be decoded for each of the initial batch elements. Note that this does not actually need to be the initial state of a decoder; decoding can start from any state.
  - beam_size (int) – The maximum number of elements allowed on the beam.
  - eos_symbol (int) – Identifier of a designated end-of-sequence (EOS) symbol that indicates that the model should stop generating symbols for a sequence.
  - max_length (int) – A hard upper limit on the number of symbols in the generated sequences. If the limit is reached, the output sequence is taken from the highest-scoring beam element at the last timestep.
  - device (device) – The device where intermediate data (log probabilities, backpointers) will be stored.
- Return type:
- Returns:
A list of decoded sequences, one per batch element in the initial state.
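The length-normalized scoring described above can be illustrated with a small self-contained sketch. This is not rau's implementation (which batches hypotheses into tensors and keeps log probabilities and backpointers on `device`); `next_distribution` is a hypothetical callback, and a hypothesis here is just a (symbols, log probability, finished) triple for one batch element.

```python
import math

def beam_search(next_distribution, beam_size, eos_symbol, max_length):
    # Each hypothesis: (symbols, total log probability, finished flag).
    beam = [([], 0.0, False)]
    for _ in range(max_length):
        candidates = []
        for symbols, logp, finished in beam:
            if finished:
                # Finished hypotheses stay on the beam unchanged.
                candidates.append((symbols, logp, True))
                continue
            for s, p in enumerate(next_distribution(symbols)):
                if p <= 0.0:
                    continue
                if s == eos_symbol:
                    candidates.append((symbols, logp + math.log(p), True))
                else:
                    candidates.append((symbols + [s], logp + math.log(p), False))

        def score(hyp):
            symbols, logp, finished = hyp
            # Length normalization: divide the log probability by the
            # number of symbols so far, counting EOS when finished.
            return logp / (len(symbols) + (1 if finished else 0))

        beam = sorted(candidates, key=score, reverse=True)[:beam_size]
        if all(finished for _, _, finished in beam):
            break
    # Read the output off the highest-scoring beam element.
    return beam[0][0]
```

For example, with a hypothetical prefix-independent distribution `[0.05, 0.05, 0.9]` over a 3-symbol vocabulary where symbol 2 is EOS, the hypothesis that emits EOS immediately wins and the result is the empty sequence.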