rau.tasks
- rau.tasks.common.add_prepare_data_args(parser)
- rau.tasks.common.validate_prepare_data_args(parser, args)
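A minimal sketch of how these two helpers might be wired into a data-preparation script, assuming they operate on a standard argparse.ArgumentParser as their parameter names suggest; the specific options they register are not shown on this page.

```python
import argparse

from rau.tasks.common import add_prepare_data_args, validate_prepare_data_args

# Hypothetical entry point for a data-preparation script. The helpers are
# assumed to register their options on an argparse parser and to report an
# error through it when the parsed arguments are inconsistent.
parser = argparse.ArgumentParser(description='Prepare a dataset for training.')
add_prepare_data_args(parser)
args = parser.parse_args()
validate_prepare_data_args(parser, args)
```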
- rau.tasks.common.get_token_types(tokens, unk_string)
- rau.tasks.common.get_token_types_in_file(path, unk_string)
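A hedged usage sketch for the two token-type helpers. The assumption that they collect the distinct token types observed, treating unk_string as the designated unknown symbol, is inferred from their names; the exact return value is not stated here, and the file path below is a placeholder.

```python
from rau.tasks.common import get_token_types, get_token_types_in_file

# Token types in an in-memory token sequence, with '<unk>' assumed to be
# the unknown-token string used by the vocabulary.
tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat', '<unk>']
token_types = get_token_types(tokens, '<unk>')

# Same computation over a tokenized text file on disk
# ('data/train.tok' is a hypothetical path).
token_types_in_file = get_token_types_in_file('data/train.tok', '<unk>')
```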
- rau.tasks.common.prepare_file(vocab, pair)
- rau.tasks.common.pad_sequences(sequences, device, pad, bos=None, eos=None, return_lengths=False)
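A sketch of padding a batch of variable-length sequences. It assumes pad_sequences accepts a list of integer tensors, pads them with the given pad index, optionally appends BOS/EOS indices, and returns the padded tensor together with the original lengths when return_lengths=True; none of these return-value details are stated on this page, and the index values are illustrative.

```python
import torch

from rau.tasks.common import pad_sequences

# Two variable-length sequences of token indices (illustrative values).
sequences = [torch.tensor([5, 7, 2, 9]), torch.tensor([3, 1])]

# Assumed behavior: pad to a common length with index 0, append an EOS
# index of 1 to each sequence, and also return the original lengths.
padded, lengths = pad_sequences(
    sequences,
    device=torch.device('cpu'),
    pad=0,
    eos=1,
    return_lengths=True,
)
```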
- rau.tasks.common.add_training_loop_arguments(parser, max_tokens_per_batch_help)
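A brief sketch of registering the training-loop options on a command-line parser. It assumes the helper adds options corresponding to the TrainingLoop constructor parameters documented below, with max_tokens_per_batch_help supplying the help text for the batch-size option, as the parameter name suggests.

```python
import argparse

from rau.tasks.common import add_training_loop_arguments

parser = argparse.ArgumentParser()
# Assumed to register command-line options for the TrainingLoop
# hyperparameters documented below.
add_training_loop_arguments(
    parser,
    max_tokens_per_batch_help='Maximum number of tokens per minibatch.',
)
args = parser.parse_args()
```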
- class rau.tasks.common.TrainingLoop(show_progress: bool, max_epochs: int, random_shuffling_seed: int, max_tokens_per_batch: int, optimizer: Literal['SGD', 'Adam'], initial_learning_rate: float, label_smoothing_factor: float | None, gradient_clipping_threshold: float | None, early_stopping_patience: int, learning_rate_patience: int, learning_rate_decay_factor: float, examples_per_checkpoint: int)
Bases: Generic[Example, PreparedBatch, VocabularyContainer]
- __init__(show_progress, max_epochs, random_shuffling_seed, max_tokens_per_batch, optimizer, initial_learning_rate, label_smoothing_factor, gradient_clipping_threshold, early_stopping_patience, learning_rate_patience, learning_rate_decay_factor, examples_per_checkpoint)
- generate_batches(examples, max_tokens)
Given the full list of examples in a dataset and a maximum number of tokens per batch, group those examples into minibatches.
- get_loss(model, model_interface, prepared_batch)
Return a differentiable tensor representing the loss function to be optimized.
- get_prepared_batch_and_loss(saver, model_interface, batch)
- get_validation_metric_mode()
Return whether the validation metric is supposed to go up (max) or down (min).
- Return type: Literal['min', 'max']
- get_validation_metric_name()
Return the name of the validation set metric used for early stopping and learning rate scheduling.
- Return type: str
- handle_out_of_cuda_memory(vocabulary, batch, info, device, console_logger, event_logger)
- log_failed_batch(vocabulary, batch, info, console_logger, event_logger)
- run(saver, model_interface, training_data, validation_data, vocabulary, console_logger, event_logger)
NOTE: When this function returns, the model’s parameters will be those of the last epoch, not necessarily the best epoch. However, the saved model will be the best one.
- run_parameter_update(saver, model_interface, optimizer, batch)
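A hedged construction sketch for TrainingLoop. The keyword arguments mirror the constructor signature shown above with illustrative values; in practice a task-specific subclass that implements methods such as generate_batches and get_loss is presumably what gets instantiated, and the saver, model interface, data, vocabulary, and loggers passed to run() would be built by the surrounding task code, so the call is only indicated in a comment.

```python
from rau.tasks.common import TrainingLoop

# Illustrative hyperparameter values; every keyword matches the
# constructor signature documented above.
training_loop = TrainingLoop(
    show_progress=True,
    max_epochs=20,
    random_shuffling_seed=123,
    max_tokens_per_batch=2048,
    optimizer='Adam',
    initial_learning_rate=1e-3,
    label_smoothing_factor=None,
    gradient_clipping_threshold=5.0,
    early_stopping_patience=4,
    learning_rate_patience=2,
    learning_rate_decay_factor=0.5,
    examples_per_checkpoint=50000,
)

# Training would then be launched with run(); the arguments below are
# placeholders that a real task script constructs beforehand.
# training_loop.run(
#     saver, model_interface, training_data, validation_data,
#     vocabulary, console_logger, event_logger,
# )
```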
- rau.tasks.common.get_random_generator_and_seed(random_seed)
- rau.tasks.common.get_random_seed(random_seed)
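A hedged sketch of the two seeding helpers, on the assumption that get_random_seed resolves a possibly unspecified seed into a concrete value and that get_random_generator_and_seed additionally returns a random generator initialized with that seed; the exact return types are not given on this page.

```python
from rau.tasks.common import get_random_generator_and_seed, get_random_seed

# Resolve a user-supplied seed (possibly None) into a concrete value so
# that the run can be reproduced later.
seed = get_random_seed(None)

# Assumed to return both a seeded random generator and the seed it used.
generator, seed = get_random_generator_and_seed(42)
```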
- rau.tasks.common.evaluate(model, model_interface, batches, evaluate_batch)
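A sketch of how evaluate might be driven, assuming evaluate_batch is a callable applied to each prepared batch and that evaluate aggregates its results over all batches; the expected signature of evaluate_batch is not documented here, and the model, model interface, and batches are placeholders that a real task script would construct, so the call is only indicated in a comment.

```python
from rau.tasks.common import evaluate

def evaluate_batch(batch):
    # Hypothetical per-batch metric computation; a real implementation
    # would run the model on the batch and return its scores.
    ...

# The objects below are produced elsewhere by the task's setup code, so
# the call is shown only schematically here.
# results = evaluate(model, model_interface, batches, evaluate_batch)
```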