Semiring Einsum (torch_semiring_einsum)

This is a PyTorch re-implementation of einsum that supports multiple semirings. It includes implementations for the real, log, and Viterbi semirings out of the box and can be extended to support additional semirings. It can also offer better performance than the built-in torch.einsum() function and makes the memory-execution time tradeoff configurable, allowing you to run large einsum operations that might otherwise be impossible given typical hardware constraints.

This einsum implementation was specifically designed to be memory-efficient. Whereas a naive implementation of einsum could easily consume huge amounts of memory, this implementation has a very conservative memory footprint. It performs summations in-place and in fixed-size blocks in order to enforce an upper bound on memory usage, at the cost of some parallelism. However, the percentage of memory saved is typically much greater than the percentage of speed lost. You can set the block size to tune the tradeoff between memory and speed.

In some cases, this einsum implementation has even better space complexity than the built-in torch.einsum() function, because it does not need to create intermediate tensors whose sizes are proportional to the dimensions being summed over.

Installation

You can install torch_semiring_einsum from PyPI using pip:

pip install torch_semiring_einsum

or a package manager like Poetry:

poetry add torch_semiring_einsum

You can also install it directly from GitHub:

pip install git+https://github.com/bdusell/semiring-einsum.git
poetry add git+https://github.com/bdusell/semiring-einsum@master

Basic Usage

Here is a quick example that implements batched matrix multiplication in log space:

import torch
import torch_semiring_einsum

# Pre-compile an einsum equation.
EQUATION = torch_semiring_einsum.compile_equation('bik,bkj->bij')
# Create some parameters to multiply.
A = torch.log(torch.rand(10, 3, 5, requires_grad=True))
B = torch.log(torch.rand(10, 5, 7, requires_grad=True))
# Run einsum.
C = torch_semiring_einsum.log_einsum(EQUATION, A, B)
# Now C is differentiable.
C.sum().backward()

Note that unlike in NumPy or PyTorch, equations are pre-compiled using compile_equation() rather than re-parsed from scratch every time einsum is called.
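
To make explicit what the example above computes, here is a minimal reference sketch that uses only built-in PyTorch operations. It relies on the meaning of the log semiring (multiplication becomes addition, and addition becomes logsumexp). The helper name log_bmm_reference is hypothetical and is not part of torch_semiring_einsum. It also materializes a tensor with an extra k dimension, so it is suitable only for checking small inputs.

```python
import torch

def log_bmm_reference(A, B):
    """Reference log-space batched matmul (hypothetical helper, not library code).

    A has shape (b, i, k) and B has shape (b, k, j); both hold log-values.
    """
    # Broadcast to shape (b, i, k, j): the semiring "product" in log space is +.
    terms = A.unsqueeze(3) + B.unsqueeze(1)
    # The semiring "sum" in log space is logsumexp, taken over the k dimension.
    return torch.logsumexp(terms, dim=2)

torch.manual_seed(0)
A = torch.log(torch.rand(10, 3, 5))
B = torch.log(torch.rand(10, 5, 7))
C = log_bmm_reference(A, B)
# Exponentiating the result should recover the ordinary batched matrix product.
expected = torch.bmm(A.exp(), B.exp())
```

Comparing C.exp() against expected with torch.allclose() is a quick way to sanity-check any log-space einsum result on small inputs.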

API Documentation

For full, detailed API documentation, see API for torch_semiring_einsum.

What is Einsum?

The so-called “einsum” function, offered in tensor math libraries such as NumPy, TensorFlow, and PyTorch, is a function that can be used to express multi-dimensional, linear algebraic tensor operations with a simple, concise syntax inspired by Einstein summation. It is a very useful kernel that can be used to implement other tensor operations; for example, the matrix-matrix product of A and B can be implemented as

C = einsum('ik,kj->ij', A, B)

In this example, the first argument to the function is the “equation,” and the lower-case letters i, j, and k all serve as labels for dimensions of the tensors A, B, and C. The left side of the equation, ik,kj, describes the dimensions of the inputs, A and B; the right side of the equation, ij, describes the desired shape of the output tensor C. This means that for each i and j, entry C[i, j] will be formed by multiplying elements from A[i, :] and B[:, j]. Since the variable k does not appear in the output, it is “summed out,” meaning that each C[i, j] is the result of computing A[i, k] * B[k, j] for each k, then summing over the resulting terms.

\[C_{ij} = \sum_k A_{ik} \times B_{kj}\]

Einsum can also be used with three or more tensor arguments.
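
The definition above can be spelled out with explicit loops. The sketch below is for illustration only (matmul_by_definition is a made-up name and is far too slow for real use); it computes the matrix product index by index, so that it can be compared with the built-in torch.einsum(), and it also shows a call with three tensor arguments.

```python
import torch

def matmul_by_definition(A, B):
    """Compute C[i, j] = sum_k A[i, k] * B[k, j] with explicit loops."""
    I, K = A.shape
    K2, J = B.shape
    assert K == K2, "the shared index k must have the same size in both inputs"
    C = torch.zeros(I, J)
    for i in range(I):
        for j in range(J):
            for k in range(K):  # k is absent from the output, so it is summed out
                C[i, j] += A[i, k] * B[k, j]
    return C

torch.manual_seed(0)
A = torch.rand(3, 5)
B = torch.rand(5, 4)
D = torch.rand(4, 2)

C_loops = matmul_by_definition(A, B)
C_einsum = torch.einsum('ik,kj->ij', A, B)

# Three arguments: both k and j are summed out, giving the chained product A @ B @ D.
E_einsum = torch.einsum('ik,kj,jl->il', A, B, D)
```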

Semirings

It is often useful to swap out addition and multiplication for different operators that have the same algebraic properties as addition and multiplication do on real numbers. We can express this using semirings. Changing the semiring used by a piece of code can result in new, useful algorithms. For example, the Viterbi Algorithm and the Forward Algorithm on Hidden Markov Models can be viewed as instances of the same algorithm instantiated with different semirings.

For a formal definition of semirings and an introduction to semirings in the context of context-free grammar parsing, see [Goo99].
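
As a small illustration of the Hidden Markov Model remark above, the following pure-Python sketch runs one recurrence under two semirings. Everything here is an assumption made for the example: the Semiring tuple, the hmm_score function, and the toy probabilities are hypothetical and are not part of this package. With ordinary (+, ×) the recurrence is the Forward Algorithm; with (max, ×) it yields the probability of the best path, as in the Viterbi Algorithm.

```python
from collections import namedtuple

# A semiring bundles an "addition", a "multiplication", and their identities.
# These names are illustrative and are not the library's API.
Semiring = namedtuple('Semiring', ['add', 'mul', 'zero', 'one'])

REAL = Semiring(add=lambda x, y: x + y, mul=lambda x, y: x * y, zero=0.0, one=1.0)
VITERBI = Semiring(add=max, mul=lambda x, y: x * y, zero=0.0, one=1.0)

def hmm_score(semiring, init, trans, emit, observations):
    """Generic HMM recurrence; the semiring decides what is computed."""
    n = len(init)
    # alpha[s] aggregates the scores of all state paths that end in state s.
    alpha = [semiring.mul(init[s], emit[s][observations[0]]) for s in range(n)]
    for obs in observations[1:]:
        new_alpha = []
        for s in range(n):
            total = semiring.zero
            for r in range(n):
                total = semiring.add(total, semiring.mul(alpha[r], trans[r][s]))
            new_alpha.append(semiring.mul(total, emit[s][obs]))
        alpha = new_alpha
    result = semiring.zero
    for a in alpha:
        result = semiring.add(result, a)
    return result

# Toy two-state model with two observation symbols (made-up numbers).
init = [0.6, 0.4]
trans = [[0.7, 0.3], [0.4, 0.6]]
emit = [[0.9, 0.1], [0.2, 0.8]]
observations = [0, 1, 0]

likelihood = hmm_score(REAL, init, trans, emit, observations)         # Forward
best_path_prob = hmm_score(VITERBI, init, trans, emit, observations)  # Viterbi
```

The only difference between the two calls is the semiring argument, which is the same idea this package applies to einsum.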

Einsum Syntax

This package supports the same einsum equation syntax as torch.einsum(), except that it does not support ellipsis (...) syntax.

Time and Space Complexity

Consider the einsum equation 'ak,ak,ak->a', where \(A\) is the size of the a dimension and \(K\) is the size of the k dimension. Implementations of einsum in NumPy and PyTorch would compute this by contracting two tensors at a time, performing two separate tensor multiplications. This means that they must create an intermediate tensor of size \(A \times K\). There is even a routine in NumPy, numpy.einsum_path(), which figures out the best contraction order. However, it should, in principle, be possible to avoid this by summing over all tensors at the same time. This is exactly what torch_semiring_einsum does, and as a result the amount of scratch space the forward pass of einsum requires remains fixed as a function of \(K\).
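
To make the difference concrete, call the three inputs \(X\), \(Y\), and \(Z\), the intermediate \(T\), and the result \(R\) (names introduced here only for illustration). Pairwise contraction, as described above, first materializes \(T\), which has \(A \times K\) entries:

\[T_{ak} = X_{ak} \times Y_{ak}, \qquad R_a = \sum_k T_{ak} \times Z_{ak}\]

Summing over all three tensors at once computes the same result without storing \(T\):

\[R_a = \sum_k X_{ak} \times Y_{ak} \times Z_{ak}\]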

In addition to performing the summations in the forward and backward passes in-place, this package implements another important innovation: performing summations in blocks of fixed size. Crucially, this allows you to strike a balance between time and memory usage, allowing you to perform einsum operations that might otherwise be impossible given typical time and GPU memory constraints.

The fixed-block method is a compromise between two extremes: (a) performing the summation in-place by iterating over every value of k one-by-one, and (b) performing the summation entirely out-of-place by creating an intermediate tensor with a new k dimension of size \(K\), then summing over k in one GPU kernel call. Method (a) is unbearably slow, and method (b) can use exorbitant amounts of memory that make it impossible to use. The fixed-block method is like method (a), except that it iterates over fixed-size ranges of k. This increases the parallelism and memory requirements of the summation calculation and decreases the number of GPU kernels launched. Smaller blocks make einsum behave more like (a), and larger blocks make it behave more like (b). But in all cases, the fixed block size ensures that the memory requirements never scale with \(K\), so the space complexity for our example would remain \(O(A)\) instead of \(O(AK)\).
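
The idea can be sketched in a few lines of plain PyTorch for the example equation. This is a simplified illustration under stated assumptions, not the package's actual implementation: the function name blocked_sum_of_products is hypothetical, it handles only 'ak,ak,ak->a' in the real semiring, and it covers only the forward pass.

```python
import torch

def blocked_sum_of_products(x, y, z, block_size):
    """Compute einsum('ak,ak,ak->a') by summing over k in fixed-size blocks.

    Illustrative sketch only; not the library's implementation.
    """
    A, K = x.shape
    result = torch.zeros(A, dtype=x.dtype, device=x.device)
    for start in range(0, K, block_size):
        end = min(start + block_size, K)
        # Temporaries have at most A x block_size entries, independent of K.
        block = x[:, start:end] * y[:, start:end] * z[:, start:end]
        # Accumulate the block's partial sums into the output in place.
        result.add_(block.sum(dim=1))
    return result

torch.manual_seed(0)
x, y, z = (torch.rand(4, 1000, dtype=torch.float64) for _ in range(3))
reference = torch.einsum('ak,ak,ak->a', x, y, z)
# block_size=1 behaves like method (a); block_size=K behaves like method (b).
results = {b: blocked_sum_of_products(x, y, z, block_size=b) for b in (1, 64, 1000)}
```

All three block sizes should agree with the reference up to floating-point error; only the number of loop iterations and the size of the temporaries change.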

These plots show how the space and time complexity of einsum('ak,ak,ak->a') (using the real semiring) vary with block size and \(K\), the size of dimension k:

[Figures: execution time as a function of block size and \(K\) (_images/time-complexity.png, _images/time-complexity-2.png)]

As we can see, execution time improves dramatically even with small increases in block size. The built-in torch.einsum() function is still much faster than the blocked versions, but when the block size is unbounded and the summation is fully parallel, this implementation is even faster than torch.einsum().

[Figures: memory usage as a function of block size and \(K\) (_images/space-complexity.png, _images/space-complexity-2.png)]

For our example, the built-in einsum implementation uses the same amount of memory as the fully out-of-place einsum (this is true for this specific equation, but it does not generally hold true for all equations). Crucially, the blocked einsum implementation has constant, rather than linear, space complexity, opening up a new world of possible einsum operations.

Bibliography

[Goo99] Joshua Goodman. Semiring parsing. Computational Linguistics, 25(4):573–605, 1999.