View source on GitHub 
Tensor contraction over specified indices and outer product.
tf.einsum( equation, *inputs, **kwargs )
Einsum allows defining Tensors by defining their elementwise computation. This computation is defined by equation
, a shorthand form based on Einstein summation. As an example, consider multiplying two matrices A and B to form a matrix C. The elements of C are given by:
or
C[i,k] = sum_j A[i,j] * B[j,k]
The corresponding einsum equation
is:
ij,jk>ik
In general, to convert the elementwise equation into the equation
string, use the following procedure (intermediate strings for matrix multiplication example provided in parentheses):
ik = sum_j ij * jk
)ik = sum_j ij , jk
)ik = ij, jk
)ij,jk>ik
)Note: If the output indices are not specified repeated indices are summed. Soij,jk>ik
can be simplified toij,jk
.
Many common operations can be expressed in this way. For example:
Matrix multiplication
m0 = tf.random.normal(shape=[2, 3]) m1 = tf.random.normal(shape=[3, 5]) e = tf.einsum('ij,jk>ik', m0, m1) # output[i,k] = sum_j m0[i,j] * m1[j, k] print(e.shape) (2, 5)
Repeated indices are summed if the output indices are not specified.
e = tf.einsum('ij,jk', m0, m1) # output[i,k] = sum_j m0[i,j] * m1[j, k] print(e.shape) (2, 5)
Dot product
u = tf.random.normal(shape=[5]) v = tf.random.normal(shape=[5]) e = tf.einsum('i,i>', u, v) # output = sum_i u[i]*v[i] print(e.shape) ()
Outer product
u = tf.random.normal(shape=[3]) v = tf.random.normal(shape=[5]) e = tf.einsum('i,j>ij', u, v) # output[i,j] = u[i]*v[j] print(e.shape) (3, 5)
Transpose
m = tf.ones(2,3) e = tf.einsum('ij>ji', m0) # output[j,i] = m0[i,j] print(e.shape) (3, 2)
Diag
m = tf.reshape(tf.range(9), [3,3]) diag = tf.einsum('ii>i', m) print(diag.shape) (3,)
Trace
# Repeated indices are summed. trace = tf.einsum('ii', m) # output[j,i] = trace(m) = sum_i m[i, i] assert trace == sum(diag) print(trace.shape) ()
Batch matrix multiplication
s = tf.random.normal(shape=[7,5,3]) t = tf.random.normal(shape=[7,3,2]) e = tf.einsum('bij,bjk>bik', s, t) # output[a,i,k] = sum_j s[a,i,j] * t[a, j, k] print(e.shape) (7, 5, 2)
This method does not support broadcasting on namedaxes. All axes with matching labels should have the same length. If you have length1 axes, use tf.squeseze
or tf.reshape
to eliminate them.
To write code that is agnostic to the number of indices in the input use an ellipsis. The ellipsis is a placeholder for "whatever other indices fit here".
For example, to perform a NumPystyle broadcastingbatchmatrix multiplication where the matrix multiply acts on the last two axes of the input, use:
s = tf.random.normal(shape=[11, 7, 5, 3]) t = tf.random.normal(shape=[11, 7, 3, 2]) e = tf.einsum('...ij,...jk>...ik', s, t) print(e.shape) (11, 7, 5, 2)
Einsum will broadcast over axes covered by the ellipsis.
s = tf.random.normal(shape=[11, 1, 5, 3]) t = tf.random.normal(shape=[1, 7, 3, 2]) e = tf.einsum('...ij,...jk>...ik', s, t) print(e.shape) (11, 7, 5, 2)
Args  

equation  a str describing the contraction, in the same format as numpy.einsum . 
*inputs  the inputs to contract (each one a Tensor ), whose shapes should be consistent with equation . 
**kwargs 

Returns  

The contracted Tensor , with shape determined by equation . 
Raises  

ValueError  If

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/einsum