API for torch_semiring_einsum

torch_semiring_einsum.compile_equation(equation)

Precompile an einsum equation for use with the einsum functions in this package.

Parameters:
    equation (str) – An equation in einsum syntax.

Return type:
    Equation

Returns:
    A precompiled equation.
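For example, an equation can be compiled once and reused across many calls (a minimal sketch; the tensor shapes are illustrative):

    import torch
    import torch_semiring_einsum

    # Compile the equation once, then reuse it for every call.
    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
    a = torch.rand(3, 4)
    b = torch.rand(4, 5)
    result = torch_semiring_einsum.einsum(EQUATION, a, b, block_size=2)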

torch_semiring_einsum.real_einsum_forward(equation, *args, block_size)

Einsum where addition and multiplication have their usual meanings.

This function has different memory and runtime characteristics than torch.einsum(), which can be tuned with block_size. Higher values of block_size result in faster runtime and higher memory usage.

In some cases, when dealing with summations over more than two input tensors at once, this implementation can have better space complexity than torch.einsum(), because it does not create intermediate tensors whose sizes are proportional to the dimensions being summed over.
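The result should match torch.einsum() up to floating-point error; only the memory/speed trade-off differs. A minimal sketch (shapes are illustrative):

    import torch
    import torch_semiring_einsum

    EQUATION = torch_semiring_einsum.compile_equation('abc,acd->abd')
    x = torch.rand(2, 3, 5)
    y = torch.rand(2, 5, 7)
    # A larger block_size trades memory for speed; the value is the same.
    out = torch_semiring_einsum.real_einsum_forward(EQUATION, x, y, block_size=4)
    assert torch.allclose(out, torch.einsum('abc,acd->abd', x, y))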

torch_semiring_einsum.real_einsum_backward(equation, args, needs_grad, grad, block_size)

Compute the derivative of real_einsum_forward().

Like the forward pass, the backward pass is done in a memory-efficient fashion by doing summations in fixed-size chunks.

Parameters:
    equation (Equation) – Precompiled einsum equation. The derivative of the einsum operation specified by this equation will be computed.
    args (Sequence[Tensor]) – The inputs to the einsum operation whose derivative is being computed.
    needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.
    grad (Tensor) – The gradient of the loss function with respect to the output of the einsum operation.
    block_size (int) – Block size used to control memory usage.

Returns:
    The gradients with respect to each of the inputs to the einsum operation. Returns None for inputs that do not require gradient.
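A sketch of invoking the backward pass by hand; here the hypothetical loss is the sum of the outputs, so the incoming gradient is all ones:

    import torch
    import torch_semiring_einsum

    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
    a = torch.rand(3, 4)
    b = torch.rand(4, 5)
    out = torch_semiring_einsum.real_einsum_forward(EQUATION, a, b, block_size=2)
    # Gradient of loss = out.sum() with respect to out.
    grad = torch.ones_like(out)
    grad_a, grad_b = torch_semiring_einsum.real_einsum_backward(
        EQUATION, [a, b], [True, True], grad, block_size=2)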

torch_semiring_einsum.einsum(equation, *args, block_size)

Differentiable version of ordinary (real) einsum.

This combines real_einsum_forward() and real_einsum_backward() into one auto-differentiable function.
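In practice it is used like torch.einsum(), but with a precompiled equation and the block_size keyword (a minimal sketch):

    import torch
    import torch_semiring_einsum

    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
    a = torch.rand(3, 4, requires_grad=True)
    b = torch.rand(4, 5, requires_grad=True)
    out = torch_semiring_einsum.einsum(EQUATION, a, b, block_size=2)
    out.sum().backward()  # populates a.grad and b.grad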

torch_semiring_einsum.log_einsum_forward(equation, *args, block_size)

Log-space einsum, where addition \(a + b\) is replaced with \(\log(\exp a + \exp b)\), and multiplication \(a \times b\) is replaced with addition \(a + b\).
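Equivalently, it computes the log of an ordinary einsum applied to exponentiated inputs, but without leaving log space. A quick sketch of that identity:

    import torch
    import torch_semiring_einsum

    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
    a = torch.randn(3, 4)
    b = torch.randn(4, 5)
    out = torch_semiring_einsum.log_einsum_forward(EQUATION, a, b, block_size=2)
    # Same value as exponentiating, multiplying and summing, then taking logs.
    expected = torch.log(torch.einsum('ij,jk->ik', a.exp(), b.exp()))
    assert torch.allclose(out, expected, atol=1e-5)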

torch_semiring_einsum.log_einsum_backward(equation, args, needs_grad, grad, block_size)

Compute the derivative of log_einsum_forward().

Like the forward pass, the backward pass is done in a memory-efficient fashion by doing summations in fixed-size chunks.

Parameters:
    equation (Equation) – Precompiled einsum equation. The derivative of the log-space einsum operation specified by this equation will be computed.
    args (Sequence[Tensor]) – The inputs to the log-space einsum operation whose derivative is being computed.
    needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.
    grad (Tensor) – The gradient of the loss function with respect to the output of the log-space einsum operation.
    block_size (int) – Block size used to control memory usage.

Returns:
    The gradients with respect to each of the inputs to the log-space einsum operation. Returns None for inputs that do not require gradient.

torch_semiring_einsum.log_einsum(equation, *args, block_size)

Differentiable version of log-space einsum.

This combines log_einsum_forward() and log_einsum_backward() into one auto-differentiable function.
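A typical use is summing over discrete variables entirely in log space, e.g. computing a log-partition function whose gradients are then marginal quantities. A sketch, assuming scalar outputs ('->') are supported as in ordinary einsum:

    import torch
    import torch_semiring_einsum

    # Log-partition function of a chain of three variables (i, j, k).
    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->')
    scores_1 = torch.randn(5, 5, requires_grad=True)
    scores_2 = torch.randn(5, 5, requires_grad=True)
    log_z = torch_semiring_einsum.log_einsum(
        EQUATION, scores_1, scores_2, block_size=3)
    log_z.backward()  # gradients of log Z are expected feature counts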

torch_semiring_einsum.log_viterbi_einsum_forward(equation, *args, block_size)

Viterbi einsum, where addition \(a + b\) is replaced with \((\max(a, b), \arg \max(a, b))\), and multiplication \(a \times b\) is replaced with log-space multiplication \(a + b\).

Returns:
    A tuple containing the max and argmax of the einsum operation. The first element of the tuple simply contains the maximum values of the terms “summed” over by einsum. The second element contains the values of the summed-out variables that maximized those terms. If the max tensor has dimension \(N_1 \times \cdots \times N_m\), and \(k\) variables were summed out, then the argmax tensor has dimension \(N_1 \times \cdots \times N_m \times k\), where the \((m+1)\)-th dimension is a \(k\)-tuple of indexes representing the argmax. The variables in the \(k\)-tuple are ordered by first appearance in the einsum equation.
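For instance, with one summed-out variable, the argmax tensor records, for each output entry, the index that achieved the maximum (a sketch):

    import torch
    import torch_semiring_einsum

    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
    a = torch.randn(3, 4)
    b = torch.randn(4, 5)
    max_vals, argmax = torch_semiring_einsum.log_viterbi_einsum_forward(
        EQUATION, a, b, block_size=2)
    # Only j is summed out, so max_vals has shape (3, 5) and argmax has
    # shape (3, 5, 1): the value of j maximizing a[i, j] + b[j, k].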

torch_semiring_einsum.semiring_einsum_forward(equation, args, block_size, func)

Implement a custom version of einsum using the callback func.

This function is the main workhorse used to implement einsum for different semirings. It takes away the burden of figuring out how to index the input tensors and sum terms in a memory-efficient way, and only requires callbacks for performing addition and multiplication. It is also flexible enough to support multiple passes through the input tensors (this feature is required for implementing logsumexp). This function is used internally by the real, log, and Viterbi semiring einsum implementations in this package and can be used to implement einsum in other semirings as well.

Note that this function only implements the forward aspect of einsum and is not differentiable. To turn your instantiation of einsum in a particular semiring into a differentiable PyTorch Function, implement its derivative and use combine() to combine the forward and backward functions into one function. Odds are, semiring_einsum_forward() can be used to implement the derivative efficiently as well (despite including “forward” in the name, there is nothing preventing you from using it as a tool in the backward step).

semiring_einsum_forward() will call func as func(compute_sum), where compute_sum is itself another function. Calling compute_sum executes a single einsum pass over the input tensors, where you supply custom functions for addition and multiplication; this is where the semiring customization really takes place. semiring_einsum_forward() returns whatever you return from func, which will usually be what is returned from compute_sum. func will often consist of a single call to compute_sum(), but there are cases where multiple passes over the inputs with different semirings are useful (e.g. for a numerically stable logsumexp implementation, one must first compute maximum values and then use them in a subsequent logsumexp step).

Here is a quick example that implements the equivalent of torch.einsum():

    def regular_einsum(equation, *args, block_size):
        def func(compute_sum):
            def add_in_place(a, b):
                a += b

            def sum_block(a, dims):
                if dims:
                    return torch.sum(a, dim=dims)
                else:
                    # This is an edge case that `torch.sum` does not
                    # handle correctly.
                    return a

            def multiply_in_place(a, b):
                a *= b

            return compute_sum(add_in_place, sum_block, multiply_in_place)
        return semiring_einsum_forward(equation, args, block_size, func)
The full signature of compute_sum is compute_sum(add_in_place, sum_block, multiply_in_place, include_indexes=False). The + and * operators are customized using add_in_place, sum_block, and multiply_in_place.

add_in_place(a, b) must be a function that accepts two values and implements a += b for the desired definition of +. Likewise, multiply_in_place(a, b) must implement a *= b for the desired definition of *. The arguments a and b are values returned from sum_block (see below) and are usually of type Tensor, although they can be something fancier for cases like Viterbi (which involves a pair of tensors: max and argmax). These functions must modify the object a in-place; the return value is ignored.

sum_block(a, dims) should be a function that “sums” over multiple dimensions in a tensor at once. It must return its result. a is always a Tensor. dims is a tuple of ints representing the dimensions in a to sum out. Take special care to handle the case where dims is an empty tuple – in particular, keep in mind that torch.sum() returns a scalar when dims is an empty tuple. Simply returning a is sufficient to handle this edge case. Note that it is always safe to return a view of a from sum_block, since a is itself never a view of the input tensors, but always a new tensor.

If include_indexes is True, then sum_block will receive a third argument, var_values, which contains the current indexes of the parts of the input tensors being summed over (sum_block is called multiple times on different slices of the inputs). var_values is a tuple of range objects representing the ranges of indexes of the current slice. var_values contains an entry for each summed variable, in order of first appearance in the equation.

Parameters:
    equation (Equation) – A precompiled equation.
    block_size (int) – To keep memory usage in check, the einsum summation is done over multiple “windows” or “blocks” of bounded size. This parameter sets the maximum size of these windows. More precisely, it defines the maximum size of the range of values of each summed variable that is included in a single window. If there are n summed variables, the size of the window tensor is proportional to block_size ** n.
    func (Callable) – A callback of the form described above.
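As a further sketch, the same callback structure can implement other semirings. For example, a max-plus (tropical) einsum, which is like Viterbi but without tracking the argmax (illustrative only, assuming the callback signatures described above):

    import torch
    from torch_semiring_einsum import semiring_einsum_forward

    def max_plus_einsum(equation, *args, block_size):
        # Max-plus (tropical) semiring: + is max, * is ordinary addition.
        def func(compute_sum):
            def add_in_place(a, b):
                torch.maximum(a, b, out=a)

            def sum_block(a, dims):
                if dims:
                    return torch.amax(a, dim=dims)
                else:
                    # Empty dims: nothing to reduce.
                    return a

            def multiply_in_place(a, b):
                a += b

            return compute_sum(add_in_place, sum_block, multiply_in_place)
        return semiring_einsum_forward(equation, args, block_size, func)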

torch_semiring_einsum.combine(forward, backward)

Combine an einsum implementation and its derivative into a single function that works with PyTorch’s autograd mechanics.

Combining separate forward and backward implementations allows more memory efficiency than would otherwise be possible.

Parameters:
    forward (Callable) – The forward implementation.
    backward (Callable) – The backward implementation. Its signature must be backward(equation, args, needs_grad, grad, block_size), and it must return a tuple of Tensor containing the gradients with respect to args. The \(i\)-th gradient may be None if needs_grad[i] is False.

Returns:
    A function whose return value is compatible with PyTorch’s autograd mechanics.
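A minimal sketch of how the pieces fit together, pairing the real forward and backward passes (this mirrors what einsum() in this package provides):

    import torch
    import torch_semiring_einsum

    my_einsum = torch_semiring_einsum.combine(
        torch_semiring_einsum.real_einsum_forward,
        torch_semiring_einsum.real_einsum_backward)

    EQUATION = torch_semiring_einsum.compile_equation('ij,jk->ik')
    a = torch.rand(3, 4, requires_grad=True)
    b = torch.rand(4, 5, requires_grad=True)
    out = my_einsum(EQUATION, a, b, block_size=2)
    out.sum().backward()  # gradients flow through the combined function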