API for torch_semiring_einsum

torch_semiring_einsum.compile_equation(equation)

Pre-compile an einsum equation for use with the einsum functions in this package.

Parameters

equation (str) – An equation in einsum syntax.

Return type

Equation

Returns

A pre-compiled equation.
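
For example, an equation can be compiled once and reused across many calls (a minimal sketch; the equation string is only an illustration):

import torch_semiring_einsum

# Compile the equation once, up front, so the parsing work is not
# repeated on every einsum call.
EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')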

torch_semiring_einsum.real_einsum_forward(equation, *args, block_size)

Einsum where addition and multiplication have their usual meanings.

This function has different memory and runtime characteristics from torch.einsum(), and these can be tuned with block_size. Higher values of block_size result in faster runtime but higher memory usage.

In some cases, when summing over more than two input tensors at once, this implementation can have better space complexity than torch.einsum(), because it does not create intermediate tensors whose sizes are proportional to the sizes of the dimensions being summed over.

Parameters
  • equation (Equation) – A pre-compiled equation.

  • args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.

  • block_size (int) – Block size used to control memory usage.

Return type

Tensor

Returns

Output of einsum.
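
A minimal usage sketch (the equation, tensor sizes, and block_size value are illustrative):

import torch
import torch_semiring_einsum

EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
A = torch.rand(3, 5)
B = torch.rand(5, 7)
# Equivalent to torch.einsum('ij,jk->ik', A, B), but computed in blocks
# of at most 4 values per summed variable.
C = torch_semiring_einsum.real_einsum_forward(EQUATION, A, B, block_size=4)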

torch_semiring_einsum.real_einsum_backward(equation, args, needs_grad, grad, block_size)

Compute the derivative of real_einsum_forward().

Like the forward pass, the backward pass is performed in a memory-efficient fashion by doing summations in fixed-size chunks.

Parameters
  • equation (Equation) – Pre-compiled einsum equation. The derivative of the einsum operation specified by this equation will be computed.

  • args (Sequence[Tensor]) – The inputs to the einsum operation whose derivative is being computed.

  • needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.

  • grad (Tensor) – The gradient of the loss function with respect to the output of the einsum operation.

  • block_size (int) – Block size used to control memory usage.

Return type

List[Optional[Tensor]]

Returns

The gradients with respect to each of the inputs to the einsum operation. Returns None for inputs that do not require gradient.
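
For example, the backward pass can be invoked directly (a sketch; tensor sizes and block_size are illustrative):

import torch
import torch_semiring_einsum

EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
A = torch.rand(3, 5)
B = torch.rand(5, 7)
# Suppose the gradient of the loss with respect to the 3 x 7 output is
# all ones, and only A needs a gradient.
grad_output = torch.ones(3, 7)
grad_A, grad_B = torch_semiring_einsum.real_einsum_backward(
    EQUATION, [A, B], [True, False], grad_output, block_size=4)
# grad_B is None because needs_grad[1] was False.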

torch_semiring_einsum.einsum(equation, *args, block_size)

Differentiable version of ordinary (real) einsum.

This combines real_einsum_forward() and real_einsum_backward() into one auto-differentiable function.
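
A sketch of typical usage with autograd (tensor sizes and block_size are illustrative):

import torch
import torch_semiring_einsum

EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
A = torch.rand(3, 5, requires_grad=True)
B = torch.rand(5, 7, requires_grad=True)
C = torch_semiring_einsum.einsum(EQUATION, A, B, block_size=4)
# Gradients flow back to A and B just as they would with torch.einsum.
C.sum().backward()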

torch_semiring_einsum.log_einsum_forward(equation, *args, block_size)

Log-space einsum, where addition \(a + b\) is replaced with \(\log(\exp a + \exp b)\), and multiplication \(a \times b\) is replaced with addition \(a + b\).

Parameters
  • equation (Equation) – A pre-compiled equation.

  • args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.

  • block_size (int) – Block size used to control memory usage.

Return type

Tensor

Returns

Output of einsum.
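
By the definition above, the result equals the log of an ordinary einsum applied to exponentiated inputs. A sketch (tensor sizes and block_size are illustrative):

import torch
import torch_semiring_einsum

EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
A = torch.rand(3, 5)
B = torch.rand(5, 7)
C = torch_semiring_einsum.log_einsum_forward(EQUATION, A, B, block_size=4)
# Up to floating-point error, this matches the naive log-space computation:
expected = torch.log(torch.einsum('ij,jk->ik', A.exp(), B.exp()))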

torch_semiring_einsum.log_einsum_backward(equation, args, needs_grad, grad, block_size)

Compute the derivative of log_einsum_forward().

Like the forward pass, the backward pass is performed in a memory-efficient fashion by doing summations in fixed-size chunks.

Parameters
  • equation (Equation) – Pre-compiled einsum equation. The derivative of the log-space einsum operation specified by this equation will be computed.

  • args (Sequence[Tensor]) – The inputs to the log-space einsum operation whose derivative is being computed.

  • needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.

  • grad (Tensor) – The gradient of the loss function with respect to the output of the log-space einsum operation.

  • block_size (int) – Block size used to control memory usage.

Return type

List[Optional[Tensor]]

Returns

The gradients with respect to each of the inputs to the log-space einsum operation. Returns None for inputs that do not require gradient.

torch_semiring_einsum.log_einsum(equation, *args, block_size)

Differentiable version of log-space einsum.

This combines log_einsum_forward() and log_einsum_backward() into one auto-differentiable function.

torch_semiring_einsum.log_viterbi_einsum_forward(equation, *args, block_size)

Viterbi einsum, where addition \(a + b\) is replaced with \((\max(a, b), \arg \max(a, b))\), and multiplication \(a \times b\) is replaced with log-space multiplication \(a + b\).

Parameters
  • equation (Equation) – A pre-compiled equation.

  • args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.

  • block_size (int) – Block size used to control memory usage.

Return type

Tuple[Tensor, LongTensor]

Returns

A tuple containing the max and argmax of the einsum operation. The first element of the tuple contains the maximum values of the terms “summed” over by einsum. The second element contains the values of the summed-out variables that maximized those terms. If the max tensor has size \(N_1 \times \cdots \times N_m\), and \(k\) variables were summed out, then the argmax tensor has size \(N_1 \times \cdots \times N_m \times k\), where the \(k\) entries along the \((m+1)\)th dimension form a \(k\)-tuple of indexes representing the argmax. The variables in this \(k\)-tuple are ordered by first appearance in the einsum equation.
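
A sketch illustrating the shapes of the returned tensors (sizes and block_size are illustrative):

import torch
import torch_semiring_einsum

EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
A = torch.rand(3, 5)
B = torch.rand(5, 7)
max_values, argmax = torch_semiring_einsum.log_viterbi_einsum_forward(
    EQUATION, A, B, block_size=4)
# Only one variable (j) is summed out, so:
#   max_values has size 3 x 7
#   argmax has size 3 x 7 x 1 and holds the maximizing value of j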

torch_semiring_einsum.semiring_einsum_forward(equation, args, block_size, func)

Implement a custom version of einsum using the callback func.

This function is the main workhorse used to implement einsum for different semirings. It takes away the burden of figuring out how to index the input tensors and sum terms in a memory-efficient way, and only requires callbacks for performing addition and multiplication. It is also flexible enough to support multiple passes through the input tensors (this feature is required for implementing logsumexp). This function is used internally by the real, log, and Viterbi semiring einsum implementations in this package and can be used to implement einsum in other semirings as well.

Note that this function only implements the forward aspect of einsum and is not differentiable. To turn your instantiation of einsum in a particular semiring into a differentiable PyTorch Function, implement its derivative and use combine() to combine the forward and backward functions into one function. Odds are, semiring_einsum_forward() can be used to implement the derivative efficiently as well (despite including “forward” in the name, there is nothing preventing you from using it as a tool in the backward step).

semiring_einsum_forward() will call func as func(compute_sum), where compute_sum is itself another function. Calling compute_sum executes a single einsum pass over the input tensors, where you supply custom functions for addition and multiplication; this is where the semiring customization really takes place. semiring_einsum_forward() returns whatever you return from func, which will usually be what is returned from compute_sum. func will often consist of a single call to compute_sum(), but there are cases where multiple passes over the inputs with different semirings are useful (e.g. for a numerically stable logsumexp implementation, one must first compute maximum values and then use them in a subsequent logsumexp step).

Here is a quick example that implements the equivalent of torch.einsum():

import torch
from torch_semiring_einsum import semiring_einsum_forward

def regular_einsum(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            a += b
        def sum_block(a, dims):
            if dims:
                return torch.sum(a, dim=dims)
            else:
                # This is an edge case that `torch.sum` does not
                # handle correctly.
                return a
        def multiply_in_place(a, b):
            a *= b
        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return semiring_einsum_forward(equation, args, block_size, func)

The full signature of compute_sum is compute_sum(add_in_place, sum_block, multiply_in_place, include_indexes=False). The + and * operators are customized using add_in_place, sum_block, and multiply_in_place.

add_in_place(a, b) must be a function that accepts two values and implements a += b for the desired definition of +. Likewise, multiply_in_place(a, b) must implement a *= b for the desired definition of *. The arguments a and b are values returned from sum_block (see below) and are usually of type Tensor, although they can be something fancier for cases like Viterbi (which involves a pair of tensors: max and argmax). These functions must modify the object a in-place; the return value is ignored.

sum_block(a, dims) should be a function that “sums” over multiple dimensions in a tensor at once. It must return its result. a is always a Tensor. dims is a tuple of ints representing the dimensions in a to sum out. Take special care to handle the case where dims is an empty tuple – in particular, keep in mind that torch.sum() returns a scalar when dims is an empty tuple. Simply returning a is sufficient to handle this edge case. Note that it is always safe to return a view of a from sum_block, since a is itself never a view of the input tensors, but always a new tensor.

If include_indexes is True, then sum_block will receive a third argument, var_values, which contains the current indexes of the parts of the input tensors being summed over (sum_block is called multiple times on different slices of the inputs). var_values is a tuple of range objects giving the ranges of indexes that make up the current slice, with one entry for each summed variable, in order of first appearance in the equation.
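
As another illustration, here is a sketch of a max-plus (tropical) semiring einsum, where + is max and * is ordinary addition. This is not part of the package’s API, just an example built on semiring_einsum_forward(); it does not need include_indexes because no argmax is tracked:

import torch
from torch_semiring_einsum import semiring_einsum_forward

def max_plus_einsum(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            # "Addition" is elementwise max, performed in place on a.
            torch.maximum(a, b, out=a)
        def sum_block(a, dims):
            if dims:
                return torch.amax(a, dim=dims)
            else:
                # Same edge case as above: nothing to reduce.
                return a
        def multiply_in_place(a, b):
            # "Multiplication" is ordinary addition.
            a += b
        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return semiring_einsum_forward(equation, args, block_size, func)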

Parameters
  • equation (Equation) – A pre-compiled equation.

  • args (Sequence[Tensor]) – A list of input tensors.

  • block_size (int) – To keep memory usage in check, the einsum summation is done over multiple “windows” or “blocks” of bounded size. This parameter sets the maximum size of these windows. More precisely, it defines the maximum size of the range of values of each summed variable that is included in a single window. If there are n summed variables, the size of the window tensor is proportional to block_size ** n.

  • func (Callable) – A callback of the form described above.

torch_semiring_einsum.combine(forward, backward)

Combine an einsum implementation and its derivative into a single function that works with PyTorch’s autograd mechanics.

Combining separate forward and backward implementations allows more memory efficiency than would otherwise be possible.

Parameters
  • forward (Callable) – The forward implementation.

  • backward (Callable) – The backward implementation. Its signature must be backward(equation, args, needs_grad, grad, block_size), and it must return a tuple of Tensor containing the gradients with respect to args. The \(i\)th gradient may be None if needs_grad[i] is False.

Return type

Callable

Returns

A function whose return value is compatible with PyTorch’s autograd mechanics.
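
For example, the differentiable einsum() documented above is, conceptually, the result of combining the real forward and backward implementations. A sketch (in practice you would simply call torch_semiring_einsum.einsum()):

import torch
import torch_semiring_einsum
from torch_semiring_einsum import combine, real_einsum_forward, real_einsum_backward

# A differentiable real einsum built from its forward and backward halves.
my_einsum = combine(real_einsum_forward, real_einsum_backward)

EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
A = torch.rand(3, 5, requires_grad=True)
B = torch.rand(5, 7, requires_grad=True)
C = my_einsum(EQUATION, A, B, block_size=4)
C.sum().backward()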