Composable Neural Networks

For users of Rau’s Python API, a useful feature of the library is the ability to create new neural network modules by composing simpler modules with the | operator, so that the output of one module becomes the input to the next. (The choice of | as the composition operator is meant to evoke piping in shell languages.) If A and B are Modules and A is also an instance of Rau’s BasicComposable class, then the expression A | B creates a new Module whose () operator passes its input to A, feeds the output of A as input to B, and returns the output of B. This module is also an instance of BasicComposable, so you can easily build a pipeline of more than two modules, like A | B | C | D | .... You can make any Module an instance of BasicComposable by wrapping it in Composable.

import torch
from torch.nn import Linear
from rau.tools.torch.compose import Composable

# Create a simple pipeline of Linear modules.
# We only need to wrap the first module in Composable to kick
# off composition.
# Note that the sizes of connecting linear layers match.
M = Composable(Linear(3, 7)) | Linear(7, 5) | Linear(5, 11)

# Feed an input to the composed module.
x = torch.ones(3)
y = M(x)

# The output size matches the output size of the last module.
assert y.size() == (11,)

This saves you the trouble of defining a custom Module subclass that implements this pipeline.
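
For comparison, a hand-written equivalent of this pipeline might look something like the following (a hypothetical sketch, not code from Rau):

from torch.nn import Linear, Module

class Pipeline(Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Linear(3, 7)
        self.layer2 = Linear(7, 5)
        self.layer3 = Linear(5, 11)

    def forward(self, x):
        # Thread the input through the layers in order.
        return self.layer3(self.layer2(self.layer1(x)))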

The full API is documented in rau.tools.torch.compose.

Composable Sequential Neural Networks

This composition feature is especially useful when dealing with sequential neural networks in Rau. Rau uses an abstraction for sequential neural networks called Unidirectional. A Unidirectional is a Module that receives a variable-length sequence of Tensors as input and produces an output Tensor for each input Tensor. Moreover, each output Tensor must not have any data dependencies on future inputs. As usual, a Unidirectional has a () operator, which receives the inputs stacked into a single Tensor along dimension 1 (the batch dimension is 0) and returns the outputs similarly stacked into a single Tensor.

import torch
from rau.models.transformer.unidirectional_encoder import (
    get_unidirectional_transformer_encoder
)

# This instantiates a causally-masked transformer encoder (also
# known as a "decoder-only" transformer). It is an instance of
# Unidirectional.
M = get_unidirectional_transformer_encoder(
    # This module will receive a sequence of integer tokens drawn
    # from a vocabulary of size 5 as input.
    input_vocabulary_size=5,
    # This module will produce a sequence of output vectors of size
    # 3, one logit per symbol in the output vocabulary.
    output_vocabulary_size=3,
    # Turn off dropout in order to make the outputs deterministic
    # for this example.
    dropout=0,
    # The remaining arguments are not relevant for this example.
    tie_embeddings=False,
    num_layers=5,
    d_model=32,
    num_heads=4,
    feedforward_size=64,
    use_padding=False
)
# Batch size.
B = 7
# Sequence length.
n = 11

# Create a batch of sequences of integer inputs in the range [0, 5)
# of length n. These are the "tokens" given to the transformer
# encoder.
x = torch.randint(5, (B, n))

# Use the () operator to get an output sequence of vectors.
# The argument include_first=False tells the module that we do not
# want it to attempt to produce an output before reading the first
# input. This is not possible for transformers, but it is for RNNs,
# which have an initial hidden state. For transformers, an output
# corresponding to an initial BOS input can serve the same purpose,
# but the BOS would need to be added to the input x, which we have
# not done in this example.
y = M(x, include_first=False)
assert y.size() == (B, n, 3)

A Unidirectional also has an initial_state() method, which returns a State object that can be used to feed inputs and read outputs one step at a time via its next() and output() methods.

from torch.testing import assert_close

state = M.initial_state(batch_size=B)
# Call .next() to feed a new input to the current state and produce
# a new state.
state = state.next(x[:, 0])
# Call .output() to get the output tensor of this state.
# Because transformers have no initial output vector before reading
# any inputs, calling .output() before .next() would have raised an
# error.
y1 = state.output()
# The output of this state is a single vector of size 3 and is
# equivalent to the first element of the output of ().
assert y1.size() == (B, 3)
assert_close(y1, y[:, 0])
# Do the same thing for a second iteration.
state = state.next(x[:, 1])
y2 = state.output()
assert y2.size() == (B, 3)
assert_close(y2, y[:, 1])

These two modes are useful in different scenarios. The () operator can be overridden to parallelize computation across the sequence dimension, making it more efficient than the iterative mode; this makes () well suited to training, where future inputs are always known in advance. The iterative mode is useful when future inputs are not known in advance, for example when generating sequences from language models or from the decoders of machine translation systems.
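
As a sketch of the iterative mode in a generation-like setting, the loop below greedily feeds each output back in as the next input, reusing M and B from above. This is a toy example; real decoding would handle BOS and EOS tokens and typically sample rather than take the argmax. It works here only because the output indices [0, 3) happen to be valid input tokens in [0, 5).

state = M.initial_state(batch_size=B)
# Feed an arbitrary first token to every sequence in the batch.
token = torch.zeros(B, dtype=torch.long)
generated = []
for _ in range(5):
    state = state.next(token)
    # Greedily pick the highest-scoring output symbol and feed it
    # back in as the next input.
    token = state.output().argmax(dim=1)
    generated.append(token)
# Stack the generated tokens into a tensor of size (B, 5).
generated = torch.stack(generated, dim=1)
assert generated.size() == (B, 5)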

Unidirectionals can also be composed with the | operator. If A and B are both Unidirectionals, then the expression A | B returns another Unidirectional that feeds its inputs to A, feeds the outputs of A as inputs to B, and returns the outputs of B. Like A and B, the Unidirectional returned by A | B also supports both () and iterative modes. If A and B implement their () and iterative modes efficiently, then A | B gives you a composed module that implements both modes efficiently for free.
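
Schematically, if first_layer and second_layer are Unidirectional instances whose sizes line up (hypothetical names, not modules provided by Rau), the composed module can be used in both modes:

# Compose two Unidirectionals; the result is also a Unidirectional.
pipeline = first_layer | second_layer

# Parallel mode: process whole input sequences at once.
outputs = pipeline(inputs, include_first=False)

# Iterative mode: process one input at a time.
state = pipeline.initial_state(batch_size=B)
state = state.next(inputs[:, 0])
first_output = state.output()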

The full API is documented in rau.unidirectional.

Argument Routing

What if you try to compose modules that require multiple arguments? For example, if you have a module A that takes no keyword arguments, a module B that requires a keyword argument foo, and a module C that requires keyword arguments bar and baz, how do you invoke A | B | C? Rau handles this by allowing you to add tags to modules that signal which modules should receive which arguments.
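
The examples below are not self-contained; here is one hypothetical way to set them up. The definitions of NeedsFoo, NeedsBarBaz, A_input_size, and batch_size are illustrative assumptions, as is the detail that tagged keyword arguments are passed through to the wrapped module's forward() method.

import torch
from torch.nn import Linear, Module
from rau.tools.torch.compose import Composable

# Hypothetical module whose forward() requires a keyword argument
# foo.
class NeedsFoo(Module):
    def __init__(self, size):
        super().__init__()
        self.linear = Linear(size, size)

    def forward(self, x, foo):
        return self.linear(x) * foo

# Hypothetical module whose forward() requires keyword arguments
# bar and baz.
class NeedsBarBaz(Module):
    def __init__(self, size):
        super().__init__()
        self.linear = Linear(size, size)

    def forward(self, x, bar, baz):
        return self.linear(x) * bar + baz

A_input_size = 3
batch_size = 7
n = 11
A = Composable(Linear(A_input_size, 4))
B = Composable(NeedsFoo(4))
C = Composable(NeedsBarBaz(4))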

# Create a pipeline where individual modules have been tagged.
M = A | B.tag('b') | C.tag('c')
x = torch.rand(batch_size, n, A_input_size)
y = M(
    # x will be passed as input to A, whose output will be passed
    # as input to B, whose output will be passed as input to C, whose
    # output will be returned as y.
    x,
    # tag_kwargs is a dict that maps tags to dicts of keyword
    # arguments. The keyword argument foo=123 will be passed to B,
    # and the keywords bar=456 and baz=789 will be passed to C.
    tag_kwargs=dict(
        b=dict(foo=123),
        c=dict(
            bar=456,
            baz=789
        )
    )
)

You can make this more succinct by designating at most one module in a pipeline as the “main” module, which will receive any extra positional or keyword arguments. This is useful when wrapping a module with input and output layers.

# Create a pipeline where B is tagged with 'b' and C is the main
# module.
M = A | B.tag('b') | C.main()
x = torch.rand(batch_size, n, A_input_size)
y = M(
    x,
    bar=456,
    baz=789,
    tag_kwargs=dict(
        b=dict(foo=123)
    )
)