API for torch_semiring_einsum

torch_semiring_einsum.compile_equation(equation)

Pre-compile an einsum equation for use with the einsum functions in this package.

Parameters:

equation (str) – An equation in einsum syntax.

Return type:

Equation

Returns:

A pre-compiled equation.
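
For example, an equation can be compiled once and reused across many einsum calls. A minimal sketch (the alias tse is only a convenience used in the examples in this document):

import torch_semiring_einsum as tse

# Compile the equation for ordinary matrix multiplication once and reuse it.
MATMUL_EQUATION = tse.compile_equation('ij,jk->ik')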

class torch_semiring_einsum.Equation

An einsum equation that has been pre-compiled into some useful data structures.

__init__(source, variable_locations, input_variables, output_variables, num_variables)

torch_semiring_einsum.einsum(equation, *args, block_size, **kwargs)

Differentiable version of ordinary (real) einsum.

This combines real_einsum_forward() and real_einsum_backward() into one auto-differentiable function.
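
A minimal usage sketch (the tensor shapes are illustrative, and block_size is passed explicitly because the signature above gives it no default):

import torch
import torch_semiring_einsum as tse

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(2, 3, requires_grad=True)
B = torch.randn(3, 4, requires_grad=True)
# Same result as torch.einsum('ij,jk->ik', A, B), but summed block by block.
C = tse.einsum(eq, A, B, block_size=tse.AUTOMATIC_BLOCK_SIZE)
C.sum().backward()  # gradients flow back to A and B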

torch_semiring_einsum.real_einsum_forward(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE)

Einsum where addition and multiplication have their usual meanings.

This function has different memory and runtime characteristics from torch.einsum(); these characteristics can be tuned with block_size. Higher values of block_size result in faster runtime and higher memory usage.

In some cases, when dealing with summations over more than two input tensors at once, this implementation can have better space complexity than torch.einsum(), because it does not create intermediate tensors whose sizes are proportional to the dimensions being summed over.

Parameters:
  • equation (Equation) – A pre-compiled equation.

  • args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.

  • block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.

Return type:

Tensor

Returns:

Output of einsum.
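
For example, block_size can be set to an explicit integer to trade memory for speed (a sketch continuing the tse alias from above; the value 32 and the shapes are arbitrary):

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(128, 64)
B = torch.randn(64, 256)
# A larger block_size is faster but uses more memory; a smaller one is leaner but slower.
out = tse.real_einsum_forward(eq, A, B, block_size=32)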

torch_semiring_einsum.real_einsum_backward(equation, args, needs_grad, grad, block_size=AUTOMATIC_BLOCK_SIZE)

Compute the derivative of real_einsum_forward().

Like the forward pass, the backward pass is performed in a memory-efficient fashion by computing summations in fixed-size chunks.

Parameters:
  • equation (Equation) – Pre-compiled einsum equation. The derivative of the einsum operation specified by this equation will be computed.

  • args (Sequence[Tensor]) – The inputs to the einsum operation whose derivative is being computed.

  • needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.

  • grad (Tensor) – The gradient of the loss function with respect to the output of the einsum operation.

  • block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.

Return type:

List[Optional[Tensor]]

Returns:

The gradients with respect to each of the inputs to the einsum operation. Returns None for inputs that do not require gradient.
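
A sketch of calling the backward pass by hand (normally einsum() does this for you through autograd; grad stands in for the gradient of some loss with respect to the output):

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(2, 3)
B = torch.randn(3, 4)
out = tse.real_einsum_forward(eq, A, B)
grad = torch.ones_like(out)  # placeholder for d(loss)/d(out)
grad_A, grad_B = tse.real_einsum_backward(eq, [A, B], [True, True], grad)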

torch_semiring_einsum.log_einsum(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE, save_max=True, save_sumexpsub=True, grad_of_neg_inf=nan)

Differentiable version of log-space einsum.

This combines log_einsum_forward() and log_einsum_backward() into one auto-differentiable function.

Parameters:
  • save_max (bool) – If true, save the tensor of maximum terms computed in the forward pass and reuse it in the backward pass. This tensor has the same size as the output tensor. Setting this to false will save memory but increase runtime.

  • save_sumexpsub (bool) – If true, save the tensor of sums of terms computed in the forward pass and reuse it in the backward pass. This tensor has the same size as the output tensor. Setting this to false will save memory but increase runtime.

Return type:

Tensor
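
A sketch of a log-space matrix product, i.e. a logsumexp over the summed variable j (shapes are illustrative; tse is the alias used in the earlier sketches):

eq = tse.compile_equation('ij,jk->ik')
log_A = torch.randn(2, 3, requires_grad=True)
log_B = torch.randn(3, 4)
# Numerically stable equivalent of torch.log(torch.exp(log_A) @ torch.exp(log_B)).
log_C = tse.log_einsum(eq, log_A, log_B)
log_C.sum().backward()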

torch_semiring_einsum.log_einsum_forward(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE, return_max=False, return_sumexpsub=False)

Log-space einsum, where addition \(a + b\) is replaced with \(\log(\exp a + \exp b)\), and multiplication \(a \times b\) is replaced with addition \(a + b\).

Parameters:
  • equation (Equation) – A pre-compiled equation.

  • args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.

  • block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.

  • return_max (bool) – If true, also return the tensor of maximum terms, which can be reused when computing the gradient.

  • return_sumexpsub (bool) – If true, also return the tensor of sums of terms (where the maximum term has been subtracted from each term), which can be reused when computing the gradient.

Return type:

Union[Tensor, Tuple[Tensor]]

Returns:

Output of einsum. If return_max or return_sumexpsub is true, the requested extra tensors are returned together with the einsum output.

torch_semiring_einsum.log_einsum_backward(equation, args, needs_grad, grad, block_size=AUTOMATIC_BLOCK_SIZE, grad_of_neg_inf=nan, saved_max=None, saved_sumexpsub=None)

Compute the derivative of log_einsum_forward().

Like the forward pass, the backward pass is performed in a memory-efficient fashion by computing summations in fixed-size chunks.

Parameters:
  • equation (Equation) – Pre-compiled einsum equation. The derivative of the log-space einsum operation specified by this equation will be computed.

  • args (Sequence[Tensor]) – The inputs to the log-space einsum operation whose derivative is being computed.

  • needs_grad (Sequence[bool]) – Indicates which inputs in args require gradient.

  • grad (Tensor) – The gradient of the loss function with respect to the output of the log-space einsum operation.

  • block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.

  • grad_of_neg_inf (Union[float, Literal['uniform']]) – How to handle the gradient of cases where all inputs to a logsumexp are \(-\infty\), which results in an output of \(-\infty\). The default behavior is to output NaN, which matches the behavior of PyTorch’s logsumexp(), but sometimes this is not desired. If a float is provided, all gradients will be set to that value. A value of 0, which causes the inputs not to change, may be appropriate. For example, if one input is a parameter and another is a constant \(-\infty\), it may not make sense to try to change the parameter. This is what the equivalent real space operation would do (the derivative of \(0x\) with respect to \(x\) is \(0\)). On the other hand, if the string 'uniform' is provided, the gradient will be set to a uniform distribution that sums to 1. This makes sense because the gradient of logsumexp is softmax, and in this case it will attempt to increase the inputs to the logsumexp above \(-\infty\). NOTE: Only NaN and 0 are currently implemented.

  • saved_max (Optional[Tensor]) – See return_max in log_einsum_forward().

  • saved_sumexpsub (Optional[Tensor]) – See return_sumexpsub in log_einsum_forward().

Return type:

List[Optional[Tensor]]

Returns:

The gradients with respect to each of the inputs to the log-space einsum operation. Returns None for inputs that do not require gradient.
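
A sketch of a manual backward call; note that the entry for an input whose needs_grad flag is False comes back as None:

eq = tse.compile_equation('ij,jk->ik')
log_A = torch.randn(2, 3)
log_B = torch.randn(3, 4)
log_C = tse.log_einsum_forward(eq, log_A, log_B)
grad = torch.ones_like(log_C)  # placeholder for d(loss)/d(log_C)
grad_log_A, grad_log_B = tse.log_einsum_backward(eq, [log_A, log_B], [True, False], grad)
# grad_log_B is None because needs_grad[1] was False.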

torch_semiring_einsum.log_viterbi_einsum_forward(equation, *args, block_size=AUTOMATIC_BLOCK_SIZE)

Viterbi einsum, where addition \(a + b\) is replaced with \((\max(a, b), \arg \max(a, b))\), and multiplication \(a \times b\) is replaced with log-space multiplication \(a + b\).

Parameters:
  • equation (Equation) – A pre-compiled equation.

  • args (Tensor) – Input tensors. The number of input tensors must be compatible with equation.

  • block_size (Union[int, AutomaticBlockSize]) – Block size used to control memory usage.

Return type:

Tuple[Tensor, LongTensor]

Returns:

A tuple containing the max and argmax of the einsum operation. The first element of the tuple simply contains the maximum values of the terms “summed” over by einsum. The second element contains the values of the summed-out variables that maximized those terms. If the max tensor has dimension \(N_1 \times \cdots \times N_m\), and \(k\) variables were summed out, then the argmax tensor has dimension \(N_1 \times \cdots \times N_m \times k\), where the \((m+1)\)th dimension is a \(k\)-tuple of indexes representing the argmax. The variables in the k-tuple are ordered by first appearance in the einsum equation.
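
A sketch with a single summed-out variable (j), so the argmax tensor gains one trailing dimension of size 1:

eq = tse.compile_equation('ij,jk->ik')
A = torch.randn(2, 3)
B = torch.randn(3, 4)
best, argmax = tse.log_viterbi_einsum_forward(eq, A, B)
# best has shape (2, 4); argmax has shape (2, 4, 1) and holds the maximizing value of j.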

torch_semiring_einsum.AUTOMATIC_BLOCK_SIZE = AUTOMATIC_BLOCK_SIZE

Use this as block_size to determine block size automatically based on available memory, according to the default arguments for AutomaticBlockSize.__init__().

class torch_semiring_einsum.AutomaticBlockSize

Indicates that the amount of memory used to sum elements in an einsum operation should be determined automatically based on the amount of available memory.

When the device is cuda, this automatically calculates the amount of free GPU memory on the current device and makes the block size as big as possible without exceeding it. When the device is cpu, this uses the value of max_cpu_bytes to determine how much memory it can use.

__init__(max_cpu_bytes=1073741824, max_cuda_bytes=None, cache_available_cuda_memory=True, cuda_memory_proportion=0.8, repr_string=None)

Parameters:
  • max_cpu_bytes (int) – The maximum amount of memory (in bytes) to use when the device is cpu. By default, this is set to 1 GiB.

  • max_cuda_bytes (Optional[int]) – The maximum amount of memory (in bytes) to use when the device is cuda. If None, then the amount of memory used will be determined based on the amount of free CUDA memory. Note that specifying an explicit memory limit is much faster than querying the amount of free CUDA memory.

  • cache_available_cuda_memory (bool) – Only applies when max_cuda_bytes is None. When true, the amount of available CUDA memory is only queried the first time einsum is called with this object as block_size, and it is reused on subsequent calls. This is significantly faster than querying the amount of available memory every time. To account for future decreases in the amount of available memory, only a portion of the available memory is used, as determined by cuda_memory_proportion.

  • cuda_memory_proportion (float) – Determines the proportion of available memory used when cache_available_cuda_memory is true. This should be a number between 0 and 1.
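
For example, an explicit CUDA memory cap can be constructed once and reused across calls (a sketch reusing eq, A, and B from the earlier sketches; the 2 GiB figure is arbitrary):

block_size = tse.AutomaticBlockSize(max_cuda_bytes=2 * 1024 ** 3)  # cap block sizing at 2 GiB on CUDA
out = tse.einsum(eq, A, B, block_size=block_size)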

torch_semiring_einsum.semiring_einsum_forward(equation, args, block_size, func)

Implement a custom version of einsum using the callback func.

This function is the main workhorse used to implement einsum for different semirings. It takes away the burden of figuring out how to index the input tensors and sum terms in a memory-efficient way, and only requires callbacks for performing addition and multiplication. It is also flexible enough to support multiple passes through the input tensors (this feature is required for implementing logsumexp). This function is used internally by the real, log, and Viterbi semiring einsum implementations in this package and can be used to implement einsum in other semirings as well.

Note that this function only implements the forward aspect of einsum and is not differentiable. To turn your instantiation of einsum in a particular semiring into a differentiable PyTorch Function, implement its derivative and use combine() to combine the forward and backward functions into one function. Odds are, semiring_einsum_forward() can be used to implement the derivative efficiently as well (despite including “forward” in the name, there is nothing preventing you from using it as a tool in the backward step).

semiring_einsum_forward() will call func as func(compute_sum), where compute_sum is itself another function. Calling compute_sum executes a single einsum pass over the input tensors, where you supply custom functions for addition and multiplication; this is where the semiring customization really takes place. semiring_einsum_forward() returns whatever you return from func, which will usually be what is returned from compute_sum. func will often consist of a single call to compute_sum(), but there are cases where multiple passes over the inputs with different semirings are useful (e.g. for a numerically stable logsumexp implementation, one must first compute maximum values and then use them in a subsequent logsumexp step).

Here is a quick example that implements the equivalent of torch.einsum():

import torch
from torch_semiring_einsum import semiring_einsum_forward

def regular_einsum(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            a += b
        def sum_block(a, dims):
            if dims:
                return torch.sum(a, dim=dims)
            else:
                # This is an edge case that `torch.sum` does not
                # handle correctly.
                return a
        def multiply_in_place(a, b):
            a *= b
        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return semiring_einsum_forward(equation, args, block_size, func)

The full signature of compute_sum is compute_sum(add_in_place, sum_block, multiply_in_place, include_indexes=False, output_dtypes=(None,)). The + and * operators are customized using add_in_place, sum_block, and multiply_in_place.

add_in_place(a, b) must be a function that accepts two values and implements a += b for the desired definition of +. Likewise, multiply_in_place(a, b) must implement a *= b for the desired definition of *. The arguments a and b are values returned from sum_block (see below) and are usually of type Tensor, although they can be something fancier for cases like Viterbi (which involves a pair of tensors: max and argmax). These functions must modify the object a in-place; the return value is ignored.

sum_block(a, dims) should be a function that “sums” over multiple dimensions in a tensor at once. It must return its result. a is always a Tensor. dims is a tuple of ints representing the dimensions in a to sum out. Take special care to handle the case where dims is an empty tuple – in particular, keep in mind that torch.sum() returns a scalar when dims is an empty tuple. Simply returning a is sufficient to handle this edge case. Note that it is always safe to return a view of a from sum_block, since a is itself never a view of the input tensors, but always a new tensor.

If include_indexes is True, then sum_block will receive a third argument, var_values, which contains the current indexes of the parts of the input tensors being summed over (sum_block is called multiple times on different slices of the inputs). var_values is a tuple of range objects giving the ranges of indexes that make up the current slice, with one entry for each summed variable, in order of first appearance in the equation.

The optional argument output_dtypes should be a list of PyTorch dtypes that represents the components of the output tensor. It is used to calculate the memory required by the output tensor when performing automatic block sizing. In most cases, the output tensor simply has the same dtype as the input tensor. In some cases, like Viterbi, the output tensor has multiple components (e.g. a tensor of floats for the max and a tensor of ints for the argmax). A dtype of None can be used to indicate the same dtype as the input tensors. The default value for output_dtypes is (None,).
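
As another illustration, here is a sketch of a max-plus (tropical) semiring einsum built the same way as the real-space example above; the name max_plus_einsum is not part of the package:

import torch
from torch_semiring_einsum import semiring_einsum_forward

def max_plus_einsum(equation, *args, block_size):
    def func(compute_sum):
        def add_in_place(a, b):
            # "Addition" is elementwise max, performed in place on a.
            torch.maximum(a, b, out=a)
        def sum_block(a, dims):
            if dims:
                return torch.amax(a, dim=dims)
            # Nothing to sum out; return a unchanged (the empty-dims edge case).
            return a
        def multiply_in_place(a, b):
            # "Multiplication" is ordinary addition.
            a += b
        return compute_sum(add_in_place, sum_block, multiply_in_place)
    return semiring_einsum_forward(equation, args, block_size, func)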

Parameters:
  • equation (Equation) – A pre-compiled equation.

  • args (Sequence[Tensor]) – A list of input tensors.

  • block_size (Union[int, AutomaticBlockSize]) – To keep memory usage in check, the einsum summation is done over multiple “windows” or “blocks” of bounded size. This parameter sets the maximum size of these windows. More precisely, it defines the maximum size of the range of values of each summed variable that is included in a single window. If there are n summed variables, the size of the window tensor is proportional to block_size ** n.

  • func (Callable) – A callback of the form described above.

torch_semiring_einsum.combine(forward, backward, forward_options=(), backward_options=())

Combine an einsum implementation and its derivative into a single function that works with PyTorch’s autograd mechanics.

Combining separate forward and backward implementations allows more memory efficiency than would otherwise be possible.

Parameters:
  • forward (Callable) – The forward implementation.

  • backward (Callable) – The backward implementation. Its signature must be backward(equation, args, needs_grad, grad, block_size), and it must return a tuple of Tensors containing the gradients with respect to args. The \(i\)th gradient may be None if needs_grad[i] is False.

  • forward_options (Tuple[str]) – The names of optional keyword arguments that should be passed on to the forward function.

  • backward_options (Tuple[str]) – The names of optional keyword arguments that should be passed on to the backward function.

Return type:

Callable

Returns:

A function whose return value is compatible with PyTorch’s autograd mechanics.
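
As noted above, the package's own einsum() is the combination of real_einsum_forward() and real_einsum_backward(), so a custom semiring can be made differentiable the same way. A sketch (the actual definition in the package may pass additional options, such as forwarding block_size):

from torch_semiring_einsum import combine, real_einsum_forward, real_einsum_backward

# Roughly how the differentiable einsum() in this package is assembled.
my_real_einsum = combine(real_einsum_forward, real_einsum_backward)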