torch.logsumexp
- torch.logsumexp(input, dim, keepdim=False, *, out=None)
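Returns the log of summed exponentials of each row of the input tensor in the given dimension dim. The computation is numerically stabilized, so large inputs do not overflow as they would in a naive log(sum(exp(input))).

For summation index \(j\) given by dim and other indices \(i\), the result is

\[\text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij})\]

If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim, where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).

- Parameters:
input (Tensor) – the input tensor.
dim (int or tuple of ints) – the dimension or dimensions to reduce.
keepdim (bool) – whether the output tensor has dim retained or not. Default: False.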
- Keyword Arguments:
out (Tensor, optional) – the output tensor.
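The shape rules for dim and keepdim, and the out keyword, can be checked with a short script. The following is a sketch; the (2, 3, 4) input shape and the variable names are arbitrary choices made for illustration.

```python
import torch

x = torch.randn(2, 3, 4)

# Reduce over dim 1: that dimension is squeezed away by default.
r = torch.logsumexp(x, dim=1)
print(r.shape)  # torch.Size([2, 4])

# With keepdim=True the reduced dimension is kept with size 1.
r_keep = torch.logsumexp(x, dim=1, keepdim=True)
print(r_keep.shape)  # torch.Size([2, 1, 4])

# A tuple of dims removes len(dim) dimensions at once.
r_multi = torch.logsumexp(x, dim=(0, 2))
print(r_multi.shape)  # torch.Size([3])

# Write the result into a preallocated tensor of the reduced shape via out=.
buf = torch.empty(2, 4)
torch.logsumexp(x, dim=1, out=buf)
print(torch.allclose(buf, r))  # True
```

Squeezing the size-1 dimension of r_keep recovers r, which is the only difference between the two keepdim settings.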
Example:
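>>> import torch
>>> a = torch.randn(3, 3)
>>> torch.logsumexp(a, 1)
tensor([1.4907, 1.0593, 1.5696])
>>> # Compare against the unstabilized formula; the difference is at round-off level.
>>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1)))
tensor(1.6859e-07)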
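The example above compares the result with a direct log(sum(exp(a))), which agrees for small values but overflows once exp(x) exceeds the floating-point range. A common way to stabilize the computation is the max-shift identity \(\log \sum_j \exp(x_{ij}) = m_i + \log \sum_j \exp(x_{ij} - m_i)\) with \(m_i = \max_j x_{ij}\). The helpers naive_logsumexp and stable_logsumexp below are hypothetical, written only to illustrate this technique; they are a sketch and not necessarily how torch.logsumexp is implemented internally.

```python
import torch

def naive_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Direct translation of the formula; exp() overflows for large x.
    return torch.log(torch.sum(torch.exp(x), dim))

def stable_logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: shift by the row maximum so the largest term is exp(0) = 1.
    m = torch.amax(x, dim=dim, keepdim=True)
    # If a row's maximum is +/-inf, shift by 0 instead to avoid inf - inf = nan.
    m = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
    shifted = torch.log(torch.sum(torch.exp(x - m), dim=dim, keepdim=True))
    return (shifted + m).squeeze(dim)

x = torch.tensor([[1000.0, 1000.0], [1.0, 2.0]])

print(naive_logsumexp(x, 1))   # first entry overflows to inf
print(stable_logsumexp(x, 1))  # finite: first entry is 1000 + log(2)
print(torch.logsumexp(x, 1))   # matches the stable sketch
```

The torch.where step matters for edge cases such as a row that is entirely -inf: the shift becomes 0, the sum of exponentials is 0, and the result is -inf rather than nan.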