GRUCell
- class torch.nn.GRUCell(input_size, hidden_size, bias=True, device=None, dtype=None)
A gated recurrent unit (GRU) cell.
\[\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
h' = (1 - z) \odot n + z \odot h
\end{array}\]
where \(x\) is the input, \(h\) is the current hidden state, \(r\), \(z\) and \(n\) are the reset, update and new gates, \(\sigma\) is the sigmoid function, and \(\odot\) is the Hadamard (element-wise) product.
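To make the gate equations concrete, here is a minimal sketch that evaluates them directly with tensor operations and compares the result with `nn.GRUCell`. The helper `gru_cell_reference` is hypothetical and written only for illustration; it assumes the packed parameters stack the reset, update and new gate blocks in the order \(r, z, n\) along their first dimension, which is the layout PyTorch documents for `nn.GRU`.

```python
import torch
from torch import nn


def gru_cell_reference(x, h, weight_ih, weight_hh, bias_ih, bias_hh):
    """One GRU step computed from the r, z, n and h' equations.

    Hypothetical helper: assumes the packed weights/biases hold the
    r, z and n blocks in that order along dim 0.
    """
    # Input and hidden projections, each of shape (N, 3 * hidden_size).
    gi = x @ weight_ih.T + bias_ih
    gh = h @ weight_hh.T + bias_hh
    # Split each projection into its r, z and n blocks.
    i_r, i_z, i_n = gi.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(i_r + h_r)   # reset gate
    z = torch.sigmoid(i_z + h_z)   # update gate
    n = torch.tanh(i_n + r * h_n)  # new gate: r scales (W_hn h + b_hn)
    return (1 - z) * n + z * h     # next hidden state h'


torch.manual_seed(0)
cell = nn.GRUCell(10, 20)
x = torch.randn(3, 10)
h = torch.randn(3, 20)

with torch.no_grad():
    expected = cell(x, h)
    manual = gru_cell_reference(x, h, cell.weight_ih, cell.weight_hh,
                                cell.bias_ih, cell.bias_hh)

print(torch.allclose(manual, expected, atol=1e-5))
```

Note that \(r\) multiplies the whole hidden projection \(W_{hn} h + b_{hn}\), bias included, rather than \(h\) itself; this is why the sketch keeps the input and hidden projections separate instead of concatenating \(x\) and \(h\) into a single matrix product.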
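- Parameters:
input_size (int) – the number of expected features in the input \(x\)
hidden_size (int) – the number of features in the hidden state \(h\)
bias (bool) – if False, the layer does not use the bias weights bias_ih and bias_hh. Default: True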
- Inputs: input, hidden
input : tensor containing input features
hidden : tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided.
- Outputs: h'
h' : tensor containing the next hidden state for each element in the batch
- Shape:
input: \((N, H_{in})\) or \((H_{in})\) tensor containing input features where \(H_{in}\) = input_size.
hidden: \((N, H_{out})\) or \((H_{out})\) tensor containing the initial hidden state where \(H_{out}\) = hidden_size. Defaults to zero if not provided.
output: \((N, H_{out})\) or \((H_{out})\) tensor containing the next hidden state.
- Variables:
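weight_ih (torch.Tensor) – the learnable input-hidden weights, of shape (3*hidden_size, input_size), stacking \(W_{ir}\), \(W_{iz}\), \(W_{in}\) in that order
weight_hh (torch.Tensor) – the learnable hidden-hidden weights, of shape (3*hidden_size, hidden_size), stacking \(W_{hr}\), \(W_{hz}\), \(W_{hn}\) in that order
bias_ih (torch.Tensor) – the learnable input-hidden bias, of shape (3*hidden_size), stacking \(b_{ir}\), \(b_{iz}\), \(b_{in}\)
bias_hh (torch.Tensor) – the learnable hidden-hidden bias, of shape (3*hidden_size), stacking \(b_{hr}\), \(b_{hz}\), \(b_{hn}\)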
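The following sketch exercises the constructor arguments, input/output shapes and variables listed above on a small cell. The sizes 10 and 20 are arbitrary choices for illustration. The parameter count follows from the variable shapes: \(3H(H_{in} + H)\) weights plus \(2 \cdot 3H\) biases, i.e. 1800 + 120 = 1920 for \(H_{in} = 10\), \(H = 20\).

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size = 10, 20
cell = nn.GRUCell(input_size, hidden_size)

# Packed parameters: the r, z and n gate blocks are stacked along dim 0.
print(cell.weight_ih.shape)                    # torch.Size([60, 10])
print(cell.weight_hh.shape)                    # torch.Size([60, 20])
print(cell.bias_ih.shape, cell.bias_hh.shape)  # torch.Size([60]) torch.Size([60])

# Total learnable parameters: 3*H*(H_in + H) weights plus 2 * 3*H biases.
n_params = sum(p.numel() for p in cell.parameters())
print(n_params)                                # 1920

# bias=False registers bias_ih and bias_hh as None, leaving only the weights.
cell_nobias = nn.GRUCell(input_size, hidden_size, bias=False)
n_params_nobias = sum(p.numel() for p in cell_nobias.parameters())
print(cell_nobias.bias_ih is None, n_params_nobias)  # True 1800

# dtype (and device) are passed through when the parameters are created.
cell64 = nn.GRUCell(input_size, hidden_size, dtype=torch.float64)
print(cell64.weight_ih.dtype)                  # torch.float64

x = torch.randn(3, input_size)                 # batched input: (N, H_in)
with torch.no_grad():
    h_default = cell(x)                        # hidden omitted -> zeros
    h_zeros = cell(x, torch.zeros(3, hidden_size))
    h_single = cell(x[0])                      # unbatched: (H_in) -> (H_out)

print(h_default.shape, torch.allclose(h_default, h_zeros))  # torch.Size([3, 20]) True
print(h_single.shape)                          # torch.Size([20])
```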
Note
All the weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{hidden\_size}}\).
On certain ROCm devices, when float16 inputs are used, this module uses a different precision for the backward pass.
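A quick way to check the initialization note is to confirm that a freshly constructed cell's parameters fall within \(\pm\sqrt{k}\), then re-draw them by hand with `torch.nn.init.uniform_`. This is an illustrative sketch: the bound is computed from \(k = 1/\text{hidden\_size}\) as stated in the note, and a tiny tolerance absorbs float32 rounding at the interval edge.

```python
import math

import torch
from torch import nn

hidden_size = 20
cell = nn.GRUCell(10, hidden_size)

# k = 1 / hidden_size, so the note's bound sqrt(k) equals 1 / sqrt(hidden_size).
bound = 1.0 / math.sqrt(hidden_size)
tol = 1e-6  # slack for float32 rounding at the interval edge

# Every freshly initialized weight and bias should lie in [-sqrt(k), sqrt(k)].
print(all(p.abs().max().item() <= bound + tol for p in cell.parameters()))

# Re-draw all parameters by hand from the same uniform distribution.
with torch.no_grad():
    for p in cell.parameters():
        nn.init.uniform_(p, -bound, bound)
```

The module also exposes `reset_parameters()`, which re-applies its default initialization without a manual loop.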
Examples:
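```python
>>> import torch
>>> from torch import nn
>>> rnn = nn.GRUCell(10, 20)
>>> input = torch.randn(6, 3, 10)  # (sequence length, batch, input_size)
>>> hx = torch.randn(3, 20)        # (batch, hidden_size)
>>> output = []
>>> for i in range(6):
...     hx = rnn(input[i], hx)      # one time step
...     output.append(hx)
```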
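The loop in the example unrolls the cell one time step at a time. As a sketch, assuming a single-layer `nn.GRU` with its default `(L, N, H_in)` input layout, copying the cell's parameters into the GRU's layer-0 parameters (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`) should reproduce the stacked per-step outputs up to floating-point tolerance.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 6, 3, 10, 20
cell = nn.GRUCell(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)  # one layer; input shape (L, N, H_in)

# Copy the cell's packed parameters into the matching layer-0 parameters.
with torch.no_grad():
    gru.weight_ih_l0.copy_(cell.weight_ih)
    gru.weight_hh_l0.copy_(cell.weight_hh)
    gru.bias_ih_l0.copy_(cell.bias_ih)
    gru.bias_hh_l0.copy_(cell.bias_hh)

x = torch.randn(seq_len, batch, input_size)
h0 = torch.randn(batch, hidden_size)

with torch.no_grad():
    # Manual unrolling, one GRUCell call per time step.
    hx = h0
    steps = []
    for t in range(seq_len):
        hx = cell(x[t], hx)
        steps.append(hx)
    cell_out = torch.stack(steps)  # (L, N, H_out)

    # nn.GRU expects the initial state as (num_layers, N, H_out).
    gru_out, h_n = gru(x, h0.unsqueeze(0))

print(torch.allclose(cell_out, gru_out, atol=1e-5))
```

The explicit loop is useful when each step needs custom logic, such as feeding a prediction back in as the next input; when the whole sequence is available up front, a single `nn.GRU` call is usually simpler and faster.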