Flatten
- class torch.nn.Flatten(start_dim=1, end_dim=-1)
Flattens a contiguous range of dimensions, from start_dim through end_dim inclusive, into a single dimension. This module form is intended for use with Sequential; see torch.flatten() for details.
- Shape:
Input: \((*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)\), where \(S_{i}\) is the size at dimension \(i\) and \(*\) means any number of dimensions, including none.
Output: \((*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)\).
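The shape rule above can be checked with a few lines of plain Python. This is a minimal sketch: `flattened_shape` is a hypothetical helper, not part of PyTorch, and it assumes a non-empty shape and `start_dim <= end_dim` once negative indices are normalized.

```python
import math


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: apply the Shape rule to a tuple of sizes.

    Assumes a non-empty shape and start_dim <= end_dim after
    normalizing negative indices (torch.flatten rejects that case).
    """
    ndim = len(shape)
    # Negative indices count from the end, so end_dim=-1 means the last dim.
    start = start_dim % ndim
    end = end_dim % ndim
    # Sizes S_start..S_end (inclusive) are replaced by their product;
    # the leading and trailing * dimensions are kept unchanged.
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


# Same input shape as in the examples: (32, 1, 5, 5).
print(flattened_shape((32, 1, 5, 5)))        # (32, 25)
print(flattened_shape((32, 1, 5, 5), 0, 2))  # (160, 5)
```

With the defaults, every dimension after the first is merged, which is why a batch of 32 single-channel 5×5 inputs becomes a `(32, 25)` matrix.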
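- Parameters:
start_dim (int) – first dim to flatten (default = 1).
end_dim (int) – last dim to flatten (default = -1).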
- Examples:
>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
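Because the module form exists mainly for use inside Sequential, here is a minimal sketch of a typical placement between a convolution and a linear layer. The layer sizes and the 10-way output are arbitrary assumptions chosen for illustration; the last lines check that the module produces the same values as the functional torch.flatten().

```python
import torch
from torch import nn

# Illustrative stack; channel counts and output size are arbitrary choices.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),  # (N, 1, 5, 5) -> (N, 4, 5, 5)
    nn.ReLU(),
    nn.Flatten(),                               # (N, 4, 5, 5) -> (N, 100)
    nn.Linear(4 * 5 * 5, 10),                   # (N, 100) -> (N, 10)
)

x = torch.randn(32, 1, 5, 5)
logits = model(x)
print(logits.shape)  # torch.Size([32, 10])

# The module computes the same values as the functional form.
flat_module = nn.Flatten(0, 2)(x)
flat_function = torch.flatten(x, start_dim=0, end_dim=2)
print(torch.equal(flat_module, flat_function))  # True
```

One design difference worth noting: nn.Flatten defaults to start_dim=1, which preserves the leading batch dimension, whereas torch.flatten() defaults to start_dim=0 and flattens the whole tensor.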