Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
multi_dot chains numpy.dot and uses the optimal parenthesization of the matrices [1] [2]. Depending on the shapes of the matrices, this can speed up the multiplication considerably.
If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
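The 1-D handling described above can be illustrated with a small sketch. The array names and shapes (`v`, `M`, `N`, `w`) are made up for this example; only `multi_dot` itself comes from the text.

```python
import numpy as np
from numpy.linalg import multi_dot

rng = np.random.default_rng(0)
v = rng.random(4)        # 1-D, used as a row vector when first
M = rng.random((4, 3))   # 2-D
N = rng.random((3, 5))   # 2-D
w = rng.random(5)        # 1-D, used as a column vector when last

row = multi_dot([v, M, N])        # 1-D first argument, 1-D result
col = multi_dot([M, N, w])        # 1-D last argument, 1-D result
scalar = multi_dot([v, M, N, w])  # 1-D at both ends, 0-d result

print(row.shape, col.shape, np.ndim(scalar))
```

Only the first and last positions accept 1-D input; the arrays in between have to be 2-D.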
Think of multi_dot as:
def multi_dot(arrays):
    return functools.reduce(np.dot, arrays)
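As a quick sanity check of that mental model, the following sketch compares the plain `functools.reduce` chain with `multi_dot` on small random matrices. The shapes are arbitrary and chosen only for illustration; the two results should agree up to floating-point rounding, since only the evaluation order may differ.

```python
import functools
import numpy as np
from numpy.linalg import multi_dot

rng = np.random.default_rng(0)
A = rng.random((4, 6))
B = rng.random((6, 3))
C = rng.random((3, 5))

# Strict left-to-right evaluation: (A @ B) @ C
chained = functools.reduce(np.dot, [A, B, C])

# Same product, with the evaluation order chosen by multi_dot
optimized = multi_dot([A, B, C])

same_result = np.allclose(chained, optimized)
print(same_result, optimized.shape)
```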
arrays (sequence of array_like): If the first argument is 1-D, it is treated as a row vector. If the last argument is 1-D, it is treated as a column vector. The other arguments must be 2-D.
out (ndarray, optional): Output argument. This must be exactly the kind of array that would be returned if it were not used. In particular, it must have the right type, must be C-contiguous, and its dtype must be the dtype that would be returned for dot(a, b). This is a performance feature, so if these conditions are not met, an exception is raised rather than attempting to be flexible.
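A minimal sketch of the output argument follows, assuming it is passed as the keyword `out`. The text above only says that "an exception" is raised for an unsuitable buffer; catching `ValueError` or `TypeError` here is an assumption based on how `numpy.dot` behaves, not something this page states.

```python
import numpy as np
from numpy.linalg import multi_dot

A = np.ones((3, 4))
B = np.ones((4, 2))
C = np.ones((2, 5))

# A suitable buffer: right shape, float64 dtype, C-contiguous (the default layout)
out = np.empty((3, 5), dtype=np.float64)
result = multi_dot([A, B, C], out=out)

# An unsuitable buffer: the integer dtype does not match the float result
bad_out = np.empty((3, 5), dtype=np.int64)
try:
    multi_dot([A, B, C], out=bad_out)
    rejected = False
except (ValueError, TypeError):  # assumed exception types, see the note above
    rejected = True

print(result[0, 0], rejected)
```

Reusing a preallocated buffer avoids allocating a new result array on every call, which is why the requirements are strict.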
Returns the dot product of the supplied arrays.
See also
numpy.dot: dot multiplication with two arguments.
The cost for a matrix multiplication can be calculated with the following function:
def cost(A, B):
    return A.shape[0] * A.shape[1] * B.shape[1]
Assume we have three matrices \(A_{10 \times 100}, B_{100 \times 5}, C_{5 \times 50}\).
The costs for the two different parenthesizations are as follows:
cost((AB)C) = 10*100*5 + 10*5*50 = 5000 + 2500 = 7500
cost(A(BC)) = 10*100*50 + 100*5*50 = 50000 + 25000 = 75000
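The arithmetic above can be reproduced with the `cost` function on arrays of the stated shapes. The placeholder arrays `AB` and `BC` are introduced here only to carry the shapes of the intermediate products; no actual multiplication is performed.

```python
import numpy as np

def cost(A, B):
    # Number of scalar multiplications for an (n x m) by (m x p) product
    return A.shape[0] * A.shape[1] * B.shape[1]

A = np.empty((10, 100))
B = np.empty((100, 5))
C = np.empty((5, 50))

# Placeholders with the shapes of the intermediate products A @ B and B @ C
AB = np.empty((A.shape[0], B.shape[1]))  # 10 x 5
BC = np.empty((B.shape[0], C.shape[1]))  # 100 x 50

cost_AB_C = cost(A, B) + cost(AB, C)  # 5000 + 2500
cost_A_BC = cost(B, C) + cost(A, BC)  # 25000 + 50000

print(cost_AB_C, cost_A_BC)  # 7500 75000
```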
[1] Cormen, “Introduction to Algorithms”, Chapter 15.2, p. 370-378
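For longer chains, the cheapest parenthesization can be found with the textbook dynamic-programming approach to matrix-chain ordering. The sketch below is an illustrative pure-Python version, not NumPy's internal code; the helper names `matrix_chain_order` and `parenthesize` are hypothetical.

```python
def matrix_chain_order(dims):
    """Matrix i has shape dims[i] x dims[i + 1].

    Returns the minimal number of scalar multiplications and a table of
    split points describing where each sub-chain is divided.
    """
    n = len(dims) - 1
    cost = [[0] * n for _ in range(n)]
    split = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):           # length of the sub-chain
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):            # try every split point
                q = (cost[i][k] + cost[k + 1][j]
                     + dims[i] * dims[k + 1] * dims[j + 1])
                if q < cost[i][j]:
                    cost[i][j] = q
                    split[i][j] = k
    return cost[0][n - 1], split

def parenthesize(split, i, j, names):
    """Render the chosen order for matrices i..j as a string."""
    if i == j:
        return names[i]
    k = split[i][j]
    left = parenthesize(split, i, k, names)
    right = parenthesize(split, k + 1, j, names)
    return "(" + left + right + ")"

# Shapes from the example: A is 10x100, B is 100x5, C is 5x50
min_cost, split = matrix_chain_order([10, 100, 5, 50])
order = parenthesize(split, 0, 2, "ABC")
print(min_cost, order)  # 7500 ((AB)C)
```

The search itself takes time cubic in the number of matrices, which is usually negligible next to the multiplications it saves.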
multi_dot allows you to write:
>>> import numpy as np
>>> from numpy.linalg import multi_dot
>>> # Prepare some data
>>> A = np.random.random((10000, 100))
>>> B = np.random.random((100, 1000))
>>> C = np.random.random((1000, 5))
>>> D = np.random.random((5, 333))
>>> # the actual dot multiplication
>>> _ = multi_dot([A, B, C, D])
instead of:
>>> _ = np.dot(np.dot(np.dot(A, B), C), D)
>>> # or
>>> _ = A.dot(B).dot(C).dot(D)
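To see whether the reordering pays off on a given machine, the two spellings can be timed side by side. This is a rough sketch: the shapes are scaled down from the example above to keep it quick, and the timings depend on hardware and the BLAS library in use, so no particular speedup is implied.

```python
import timeit
import numpy as np
from numpy.linalg import multi_dot

rng = np.random.default_rng(0)
A = rng.random((1000, 100))
B = rng.random((100, 200))
C = rng.random((200, 5))
D = rng.random((5, 33))

# Both spellings compute the same product, up to floating-point rounding
left_to_right = A.dot(B).dot(C).dot(D)
optimized = multi_dot([A, B, C, D])
results_match = np.allclose(left_to_right, optimized)

# Time a few repetitions of each variant
t_chain = timeit.timeit(lambda: A.dot(B).dot(C).dot(D), number=5)
t_multi = timeit.timeit(lambda: multi_dot([A, B, C, D]), number=5)
print(f"chained dot: {t_chain:.4f} s, multi_dot: {t_multi:.4f} s")
```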
© 2005–2024 NumPy Developers
Licensed under the 3-clause BSD License.
https://numpy.org/doc/2.4/reference/generated/numpy.linalg.multi_dot.html