NBeatsBlock
- class NBeatsBlock(input_size: int, theta_size: int, basis_function: torch.nn.modules.module.Module, num_layers: int, layer_size: int)
Bases: torch.nn.modules.module.Module
Base N-BEATS block which takes a basis function as an argument.
- Parameters
input_size (int) – In-sample size, i.e. the length of the look-back window fed to the block.
theta_size (int) – Number of parameters (theta) produced for the basis function.
basis_function (nn.Module) – Basis function that takes the parameters theta and produces the backcast and forecast.
num_layers (int) – Number of hidden layers in the block.
layer_size (int) – Number of units in each hidden layer.
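The basis_function argument only needs to be an nn.Module that maps theta to a (backcast, forecast) pair. The following is a minimal sketch assuming a "generic" basis that uses theta directly, splitting it into a backcast part and a forecast part, so that theta_size equals backcast_size + forecast_size. GenericBasis, backcast_size and forecast_size are hypothetical names for illustration, not part of this API.

```python
from typing import Tuple

import torch
from torch import nn


class GenericBasis(nn.Module):
    """Hypothetical basis: theta is used directly as backcast and forecast.

    Assumes theta_size == backcast_size + forecast_size.
    """

    def __init__(self, backcast_size: int, forecast_size: int):
        super().__init__()
        self.backcast_size = backcast_size
        self.forecast_size = forecast_size

    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # First backcast_size values become the backcast,
        # last forecast_size values become the forecast.
        backcast = theta[:, : self.backcast_size]
        forecast = theta[:, -self.forecast_size :]
        return backcast, forecast


basis = GenericBasis(backcast_size=10, forecast_size=5)
theta = torch.randn(3, 15)  # batch of 3, theta_size = 10 + 5
backcast, forecast = basis(theta)
print(backcast.shape, forecast.shape)  # torch.Size([3, 10]) torch.Size([3, 5])
```

Other bases (for example polynomial trend or Fourier seasonality) would follow the same contract: take theta, return a (backcast, forecast) tuple.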
Methods
forward(x) – Forward pass.
- forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Forward pass.
- Parameters
x (torch.Tensor) – Input tensor holding the in-sample window (last dimension of length input_size).
- Returns
A tuple (backcast, forecast), as produced by the basis function.
- Return type
Tuple[torch.Tensor, torch.Tensor]
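The internals of forward are not spelled out above. The sketch below assumes the usual N-BEATS block design: a stack of num_layers ReLU-activated linear layers, a linear projection to theta_size parameters, and a final call to basis_function. BlockSketch and SplitBasis are hypothetical stand-ins written to mirror the documented constructor signature; they are not this library's implementation.

```python
from typing import Tuple

import torch
from torch import nn


class SplitBasis(nn.Module):
    """Hypothetical basis: split theta into backcast and forecast parts."""

    def __init__(self, backcast_size: int, forecast_size: int):
        super().__init__()
        self.backcast_size = backcast_size
        self.forecast_size = forecast_size

    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return theta[:, : self.backcast_size], theta[:, -self.forecast_size :]


class BlockSketch(nn.Module):
    """Hypothetical stand-in for NBeatsBlock (assumed FC + ReLU stack)."""

    def __init__(self, input_size: int, theta_size: int, basis_function: nn.Module,
                 num_layers: int, layer_size: int):
        super().__init__()
        # num_layers linear layers: input_size -> layer_size -> ... -> layer_size
        self.layers = nn.ModuleList(
            [nn.Linear(input_size, layer_size)]
            + [nn.Linear(layer_size, layer_size) for _ in range(num_layers - 1)]
        )
        # Linear projection from the last hidden state to theta
        self.basis_parameters = nn.Linear(layer_size, theta_size)
        self.basis_function = basis_function

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = x
        for layer in self.layers:
            h = torch.relu(layer(h))
        theta = self.basis_parameters(h)
        # The basis function turns theta into (backcast, forecast)
        return self.basis_function(theta)


input_size, forecast_size = 10, 5
block = BlockSketch(
    input_size=input_size,
    theta_size=input_size + forecast_size,
    basis_function=SplitBasis(input_size, forecast_size),
    num_layers=4,
    layer_size=64,
)
x = torch.randn(8, input_size)  # (batch_size, input_size)
backcast, forecast = block(x)
# In a doubly residual N-BEATS stack, the next block would receive x - backcast
residual = x - backcast
print(backcast.shape, forecast.shape)  # torch.Size([8, 10]) torch.Size([8, 5])
```

Feeding each block the residual and summing the per-block forecasts is how N-BEATS combines blocks; that stacking logic is not part of the block interface described here.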