API for torch_semiring_einsum¶
- torch_semiring_einsum.compile_equation(equation)¶
Pre-compile an einsum equation for use with the einsum functions in this package.
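A minimal usage sketch (the module alias tse and the 'ij,jk->ik' equation are illustrative; the equation syntax follows torch.einsum):

```python
import torch_semiring_einsum as tse

# Compile once, reuse across many einsum calls.
EQUATION = tse.compile_equation('ij,jk->ik')
```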
- class torch_semiring_einsum.Equation¶
An einsum equation that has been pre-compiled into some useful data structures.
- __init__(source, variable_locations, input_variables, output_variables, num_variables)¶
- torch_semiring_einsum.einsum(equation, *args, block_size, **kwargs)¶
Differentiable version of ordinary (real) einsum.
This combines real_einsum_forward() and real_einsum_backward() into one auto-differentiable function.
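A hedged sketch of the differentiable entry point; the shapes and the block size of 32 are arbitrary, and block_size is passed explicitly because the signature above gives it no default:

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)

# Ordinary (real) einsum, differentiable end to end.
out = tse.einsum(eq, A, B, block_size=32)
out.sum().backward()   # gradients reach A and B via real_einsum_backward()
```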
- torch_semiring_einsum.real_einsum_forward(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE)¶
Einsum where addition and multiplication have their usual meanings.
This function has different memory and runtime characteristics than torch.einsum(), which can be tuned with block_size. Higher values of block_size result in faster runtime and higher memory usage.
In some cases, when dealing with summations over more than two input tensors at once, this implementation can have better space complexity than torch.einsum(), because it does not create intermediate tensors whose sizes are proportional to the dimensions being summed over.
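For illustration, a sketch comparing the forward pass against torch.einsum() on a three-operand equation; the shapes are arbitrary, and block_size=8 is just an example of the runtime/memory trade-off described above:

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk,kl->il')
A, B, C = torch.randn(2, 3), torch.randn(3, 4), torch.randn(4, 5)

# Smaller block_size -> less memory, more passes over the summed indexes.
out = tse.real_einsum_forward(eq, A, B, C, block_size=8)
ref = torch.einsum('ij,jk,kl->il', A, B, C)
assert torch.allclose(out, ref, atol=1e-5)
```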
- torch_semiring_einsum.real_einsum_backward(equation, args, needs_grad, grad, block_size=AUTOMATIC_BLOCK_SIZE)¶
Compute the derivative of real_einsum_forward().
Like the forward pass, the backward pass is done in a memory-efficient fashion by performing summations in fixed-size chunks.
- Parameters:
  - equation (Equation) – Pre-compiled einsum equation. The derivative of the einsum operation specified by this equation will be computed.
  - args (Sequence[Tensor]) – The inputs to the einsum operation whose derivative is being computed.
  - needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.
  - grad (Tensor) – The gradient of the loss function with respect to the output of the einsum operation.
  - block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.
- Return type:
- Returns:
  The gradients with respect to each of the inputs to the einsum operation. Returns None for inputs that do not require gradient.
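A sketch of calling the backward pass by hand (normally einsum() does this for you); the all-ones grad stands in for the gradient of some loss:

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(3, 4)
B = torch.randn(4, 5)

out = tse.real_einsum_forward(eq, A, B)
grad_out = torch.ones_like(out)   # stand-in for dL/d(out)

# Ask for the gradient of A only; the slot for B comes back as None.
grad_A, grad_B = tse.real_einsum_backward(eq, [A, B], [True, False], grad_out)
assert grad_B is None
```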
- torch_semiring_einsum.log_einsum(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE, save_max=True, save_sumexpsub=True, grad_of_neg_inf=nan)¶
Differentiable version of log-space einsum.
This combines log_einsum_forward() and log_einsum_backward() into one auto-differentiable function.
- Parameters:
  - save_max (bool) – If true, save the tensor of maximum terms computed in the forward pass and reuse it in the backward pass. This tensor has the same size as the output tensor. Setting this to false will save memory but increase runtime.
  - save_sumexpsub (bool) – If true, save the tensor of sums of terms computed in the forward pass and reuse it in the backward pass. This tensor has the same size as the output tensor. Setting this to false will save memory but increase runtime.
- Return type:
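A short sketch of the differentiable log-space version; the shapes are arbitrary and save_max/save_sumexpsub are left at their defaults:

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)

# Log-space matrix "product": logsumexp over the summed-out index j.
out = tse.log_einsum(eq, A, B)
out.sum().backward()
```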
- torch_semiring_einsum.log_einsum_forward(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE, return_max=False, return_sumexpsub=False)¶
Log-space einsum, where addition \(a + b\) is replaced with \(\log(\exp a + \exp b)\), and multiplication \(a \times b\) is replaced with addition \(a + b\).
- Parameters:
  - equation (Equation) – A pre-compiled equation.
  - args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.
  - block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.
  - return_max (bool) – If true, also return the tensor of maximum terms, which can be reused when computing the gradient.
  - return_sumexpsub (bool) – If true, also return the tensor of sums of terms (where the maximum term has been subtracted from each term), which can be reused when computing the gradient.
- Return type:
- Returns:
  Output of einsum. If return_max or return_sumexpsub is true, the output will be a list containing the extra outputs.
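As a sanity check, the output should agree with a direct logsumexp formulation; a sketch with arbitrary shapes:

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(3, 4)
B = torch.randn(4, 5)

out = tse.log_einsum_forward(eq, A, B)
# Reference: add in log space, then logsumexp over the summed-out index j.
ref = torch.logsumexp(A.unsqueeze(2) + B.unsqueeze(0), dim=1)
assert torch.allclose(out, ref, atol=1e-5)
```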
- torch_semiring_einsum.log_einsum_backward(equation, args, needs_grad, grad, block_size=AUTOMATIC_BLOCK_SIZE, grad_of_neg_inf=nan, saved_max=None, saved_sumexpsub=None)¶
Compute the derivative of log_einsum_forward().
Like the forward pass, the backward pass is done in a memory-efficient fashion by performing summations in fixed-size chunks.
- Parameters:
  - equation (Equation) – Pre-compiled einsum equation. The derivative of the log-space einsum operation specified by this equation will be computed.
  - args (Sequence[Tensor]) – The inputs to the log-space einsum operation whose derivative is being computed.
  - needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.
  - grad (Tensor) – The gradient of the loss function with respect to the output of the log-space einsum operation.
  - block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.
  - grad_of_neg_inf (Union[float, Literal['uniform']]) – How to handle the gradient of cases where all inputs to a logsumexp are \(-\infty\), which results in an output of \(-\infty\). The default behavior is to output NaN, which matches the behavior of PyTorch’s logsumexp(), but sometimes this is not desired. If a float is provided, all gradients will be set to that value. A value of 0, which causes the inputs not to change, may be appropriate. For example, if one input is a parameter and another is a constant \(-\infty\), it may not make sense to try to change the parameter. This is what the equivalent real space operation would do (the derivative of \(0x\) with respect to \(x\) is \(0\)). On the other hand, if the string 'uniform' is provided, the gradient will be set to a uniform distribution that sums to 1. This makes sense because the gradient of logsumexp is softmax, and in this case it will attempt to increase the inputs to the logsumexp above \(-\infty\). NOTE: Only NaN and 0 are currently implemented.
  - saved_max (Optional[Tensor]) – See return_max in log_einsum_forward().
  - saved_sumexpsub (Optional[Tensor]) – See return_sumexpsub in log_einsum_forward().
- Return type:
- Returns:
  The gradients with respect to each of the inputs to the log-space einsum operation. Returns None for inputs that do not require gradient.
- torch_semiring_einsum.log_viterbi_einsum_forward(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE)¶
Viterbi einsum, where addition \(a + b\) is replaced with \((\max(a, b), \arg \max(a, b))\), and multiplication \(a \times b\) is replaced with log-space multiplication \(a + b\).
- Parameters:
  - equation (Equation) – A pre-compiled equation.
  - args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.
  - block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.
- Return type:
- Returns:
A tuple containing the max and argmax of the einsum operation. The first element of the tuple simply contains the maximum values of the terms “summed” over by einsum. The second element contains the values of the summed-out variables that maximized those terms. If the max tensor has dimension \(N_1 \times \cdots \times N_m\), and \(k\) variables were summed out, then the argmax tensor has dimension \(N_1 \times \cdots \times N_m \times k\), where the \((m+1)\)th dimension is a \(k\)-tuple of indexes representing the argmax. The variables in the k-tuple are ordered by first appearance in the einsum equation.
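A sketch illustrating the output shapes described above (one variable, j, is summed out, so the argmax tensor gains a trailing dimension of size 1):

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(3, 4)
B = torch.randn(4, 5)

max_vals, argmax = tse.log_viterbi_einsum_forward(eq, A, B)
print(max_vals.shape)  # torch.Size([3, 5])
print(argmax.shape)    # torch.Size([3, 5, 1]) -- index of the best j per cell
```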
- torch_semiring_einsum.AUTOMATIC_BLOCK_SIZE = AUTOMATIC_BLOCK_SIZE¶
Use this as block_size to determine block size automatically based on available memory, according to the default arguments for AutomaticBlockSize.__init__().
- class torch_semiring_einsum.AutomaticBlockSize¶
Indicates that the amount of memory used to sum elements in an einsum operation should be determined automatically based on the amount of available memory.
When the device is cuda, this automatically calculates the amount of free GPU memory on the current device and makes the block size as big as possible without exceeding it. When the device is cpu, this uses the value of max_cpu_bytes to determine how much memory it can use.
- __init__(max_cpu_bytes=1073741824, max_cuda_bytes=None, cache_available_cuda_memory=True, cuda_memory_proportion=0.8, repr_string=None)¶
- Parameters:
  - max_cpu_bytes (int) – The maximum amount of memory (in bytes) to use when the device is cpu. By default, this is set to 1 GiB.
  - max_cuda_bytes (Optional[int]) – The maximum amount of memory (in bytes) to use when the device is cuda. If None, then the amount of memory used will be determined based on the amount of free CUDA memory. Note that specifying an explicit memory limit is much faster than querying the amount of free CUDA memory.
  - cache_available_cuda_memory (bool) – Only applies when max_cuda_bytes is None. When true, the amount of available CUDA memory is only queried the first time einsum is called with this object as block_size, and it is reused on subsequent calls. This is significantly faster than querying the amount of available memory every time. To account for future decreases in the amount of available memory, only a portion of the available memory is used, as determined by cuda_memory_proportion.
  - cuda_memory_proportion (float) – Determines the proportion of available memory used when cache_available_cuda_memory is true. This should be a number between 0 and 1.
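A sketch of supplying explicit memory limits; the byte values and tensor shapes are arbitrary examples:

```python
import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A, B = torch.randn(100, 200), torch.randn(200, 300)

# Default automatic sizing:
out = tse.real_einsum_forward(eq, A, B, block_size=tse.AUTOMATIC_BLOCK_SIZE)

# Explicit caps: 256 MiB on CPU, 2 GiB on CUDA (avoids querying free CUDA memory).
limits = tse.AutomaticBlockSize(
    max_cpu_bytes=256 * 1024 ** 2,
    max_cuda_bytes=2 * 1024 ** 3,
)
out = tse.real_einsum_forward(eq, A, B, block_size=limits)
```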
- torch_semiring_einsum.semiring_einsum_forward(equation, args, block_size, func)¶
Implement a custom version of einsum using the callback func.
This function is the main workhorse used to implement einsum for different semirings. It takes away the burden of figuring out how to index the input tensors and sum terms in a memory-efficient way, and only requires callbacks for performing addition and multiplication. It is also flexible enough to support multiple passes through the input tensors (this feature is required for implementing logsumexp). This function is used internally by the real, log, and Viterbi semiring einsum implementations in this package and can be used to implement einsum in other semirings as well.
Note that this function only implements the forward aspect of einsum and is not differentiable. To turn your instantiation of einsum in a particular semiring into a differentiable PyTorch Function, implement its derivative and use combine() to combine the forward and backward functions into one function. Odds are, semiring_einsum_forward() can be used to implement the derivative efficiently as well (despite including “forward” in the name, there is nothing preventing you from using it as a tool in the backward step).
semiring_einsum_forward() will call func as func(compute_sum), where compute_sum is itself another function. Calling compute_sum executes a single einsum pass over the input tensors, where you supply custom functions for addition and multiplication; this is where the semiring customization really takes place. semiring_einsum_forward() returns whatever you return from func, which will usually be what is returned from compute_sum. func will often consist of a single call to compute_sum(), but there are cases where multiple passes over the inputs with different semirings are useful (e.g. for a numerically stable logsumexp implementation, one must first compute maximum values and then use them for a subsequent logsumexp step).
Here is a quick example that implements the equivalent of torch.einsum():

```python
import torch
from torch_semiring_einsum import semiring_einsum_forward

def regular_einsum(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            a += b

        def sum_block(a, dims):
            if dims:
                return torch.sum(a, dim=dims)
            else:
                # This is an edge case that `torch.sum` does not
                # handle correctly.
                return a

        def multiply_in_place(a, b):
            a *= b

        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return semiring_einsum_forward(equation, args, block_size, func)
```
The full signature of compute_sum is compute_sum(add_in_place, sum_block, multiply_in_place, include_indexes=False, output_dtypes=(None,)). The + and * operators are customized using add_in_place, sum_block, and multiply_in_place.
add_in_place(a, b) must be a function that accepts two values and implements a += b for the desired definition of +. Likewise, multiply_in_place(a, b) must implement a *= b for the desired definition of *. The arguments a and b are values returned from sum_block (see below) and are usually of type Tensor, although they can be something fancier for cases like Viterbi (which involves a pair of tensors: max and argmax). These functions must modify the object a in-place; the return value is ignored.
sum_block(a, dims) should be a function that “sums” over multiple dimensions in a tensor at once. It must return its result. a is always a Tensor. dims is a tuple of ints representing the dimensions in a to sum out. Take special care to handle the case where dims is an empty tuple – in particular, keep in mind that torch.sum() returns a scalar when dims is an empty tuple. Simply returning a is sufficient to handle this edge case. Note that it is always safe to return a view of a from sum_block, since a is itself never a view of the input tensors, but always a new tensor.
If include_indexes is True, then sum_block will receive a third argument, var_values, which contains the current indexes of the parts of the input tensors being summed over (sum_block is called multiple times on different slices of the inputs). var_values is a tuple of range objects representing the ranges of indexes that make up the current slice. var_values contains an entry for each summed variable, in order of first appearance in the equation.
The optional argument output_dtypes should be a list of PyTorch dtypes that represents the components of the output tensor. It is used to calculate the memory required by the output tensor when performing automatic block sizing. In most cases, the output tensor simply has the same dtype as the input tensor. In some cases, like Viterbi, the output tensor has multiple components (e.g. a tensor of floats for the max and a tensor of ints for the argmax). A dtype of None can be used to indicate the same dtype as the input tensors. The default value for output_dtypes is (None,).
- Parameters:
  - equation (Equation) – A pre-compiled equation.
  - block_size (Union[int, AutomaticBlockSize]) – To keep memory usage in check, the einsum summation is done over multiple “windows” or “blocks” of bounded size. This parameter sets the maximum size of these windows. More precisely, it defines the maximum size of the range of values of each summed variable that is included in a single window. If there are n summed variables, the size of the window tensor is proportional to block_size ** n.
  - func (Callable) – A callback of the form described above.
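Building on the example above, here is a hedged sketch of a max-plus (tropical) einsum, where semiring addition is max and semiring multiplication is +; the name max_plus_einsum is hypothetical:

```python
import torch
from torch_semiring_einsum import semiring_einsum_forward

def max_plus_einsum(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            # Semiring addition: elementwise maximum, written into a.
            torch.maximum(a, b, out=a)

        def sum_block(a, dims):
            # "Sum" out the given dimensions by taking the max over them.
            if dims:
                return torch.amax(a, dim=dims)
            else:
                # Same edge case as above: nothing to reduce.
                return a

        def multiply_in_place(a, b):
            # Semiring multiplication: ordinary addition.
            a += b

        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return semiring_einsum_forward(equation, args, block_size, func)
```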
- torch_semiring_einsum.combine(forward, backward, forward_options=(), backward_options=())¶
Combine an einsum implementation and its derivative into a single function that works with PyTorch’s autograd mechanics.
Combining separate forward and backward implementations allows more memory efficiency than would otherwise be possible.
- Parameters:
  - forward (Callable) – The forward implementation.
  - backward (Callable) – The backward implementation. Its signature must be backward(equation, args, needs_grad, grad, block_size), and it must return a tuple of Tensor containing the gradients with respect to args. The \(i\)th gradient may be None if needs_grad[i] is False.
  - forward_options (Tuple[str]) – A list of optional keyword arguments that should be passed to the forward function.
  - backward_options (Tuple[str]) – A list of optional keyword arguments that should be passed to the backward function.
- Return type:
- Returns:
A function whose return value is compatible with PyTorch’s autograd mechanics.
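A brief sketch: recombining the real forward and backward passes shipped with this package into a differentiable function, which is essentially what einsum() above is. This assumes the combined function shares einsum()'s calling convention (equation, *args, block_size=...):

```python
import torch
from torch_semiring_einsum import (
    combine, compile_equation, real_einsum_forward, real_einsum_backward)

# Differentiable real einsum assembled from its two halves.
my_einsum = combine(real_einsum_forward, real_einsum_backward)

eq = compile_equation('ij,jk->ik')
A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)
my_einsum(eq, A, B, block_size=16).sum().backward()
```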