torch.tensordot
- torch.tensordot(a, b, dims=2, out=None)
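Returns a contraction of a and b over multiple dimensions. tensordot() implements a generalized matrix product.
- Parameters:
  - a (Tensor): left tensor to contract
  - b (Tensor): right tensor to contract
  - dims (int, or a pair of lists of ints): number of dimensions to contract, or explicit lists of dimensions for a and b respectively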
When called with a non-negative integer argument dims = \(d\), where a and b have \(m\) and \(n\) dimensions, respectively, tensordot() computes

\[r_{i_0,...,i_{m-d-1}, j_d,...,j_{n-1}} = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d-1},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, j_d,...,j_{n-1}}. \]

When called with dims of the list form, the given dimensions will be contracted in place of the last \(d\) of a and the first \(d\) of b. The sizes in these dimensions must match, but tensordot() also handles broadcast dimensions (a contracted dimension of size 1 on one side).

Examples:
>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> c = torch.tensordot(a, b, dims=2).cpu()
>>> c
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

>>> a = torch.randn(3, 5, 4, 6)
>>> b = torch.randn(6, 4, 5, 3)
>>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
tensor([[  7.7193,  -2.4867, -10.3204],
        [  1.5513, -14.4737,  -6.5113],
        [ -0.2850,   4.2573,  -3.5997]])
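
As a rough cross-check of the two calling forms, the sketch below compares the integer form against an explicit reshape followed by a matrix product, and the list form against what should be an equivalent torch.einsum call. The variable names, the einsum subscripts, and the tolerance are illustrative assumptions, not part of the tensordot() documentation.

```python
import torch

torch.manual_seed(0)

# Integer form: dims=2 contracts the last 2 dims of a with the first 2 dims of b.
a = torch.randn(3, 4, 5)
b = torch.randn(4, 5, 6)
r_int = torch.tensordot(a, b, dims=2)  # result shape: (3, 6)

# The same contraction as a plain matrix product: flatten the contracted
# (4, 5) axes into a single axis of length 20 on both operands.
r_matmul = a.reshape(3, 20) @ b.reshape(20, 6)
print(torch.allclose(r_int, r_matmul, atol=1e-4))

# List form: contract x's dims (1, 0) with y's dims (0, 1), in that pairing.
x = torch.arange(60.).reshape(3, 4, 5)
y = torch.arange(24.).reshape(4, 3, 2)
r_list = torch.tensordot(x, y, dims=([1, 0], [0, 1]))  # result shape: (5, 2)

# The same contraction as an einsum: x is indexed (i, j, k), y is indexed
# (j, i, l), and the shared indices i and j are summed over.
r_einsum = torch.einsum("ijk,jil->kl", x, y)
print(torch.allclose(r_list, r_einsum))
```

The free dimensions of the first operand come first in the result, followed by the free dimensions of the second operand, which is why the einsum output subscripts are written as `kl`.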