NBeats
- class NBeats(blocks: torch.nn.modules.container.ModuleList)
Bases: torch.nn.modules.module.Module
N-BEATS model.
Initialize N-BEATS model.
- Parameters
blocks (nn.ModuleList) – Model blocks.
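The `blocks` argument is a plain `nn.ModuleList`. As a minimal sketch, assuming each block maps an input window to a `(backcast, forecast)` pair as described in the N-BEATS paper, a list of toy fully connected blocks could be built as follows. `ToyBlock`, `hidden`, and the window sizes are hypothetical and are not this library's block classes.

```python
from typing import Tuple

import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical fully connected N-BEATS-style block (not the library's block class).

    Maps an input window of length `input_size` to a backcast of the same
    length and a forecast of length `output_size`.
    """

    def __init__(self, input_size: int, output_size: int, hidden: int = 32) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(input_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        # One linear head per output; together they stand in for the basis expansion.
        self.backcast_head = nn.Linear(hidden, input_size)
        self.forecast_head = nn.Linear(hidden, output_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.body(x)
        return self.backcast_head(h), self.forecast_head(h)


input_size, output_size, num_blocks = 24, 6, 3
# Wrapping the blocks in nn.ModuleList registers their parameters on any parent
# module, so they are moved, saved, and optimized together with it.
blocks = nn.ModuleList([ToyBlock(input_size, output_size) for _ in range(num_blocks)])

x = torch.randn(8, input_size)  # batch of 8 input windows
backcast, forecast = blocks[0](x)
print(backcast.shape, forecast.shape)  # torch.Size([8, 24]) torch.Size([8, 6])
```

Using `nn.ModuleList` rather than a Python list matters here: a plain list would hide the blocks' parameters from `parameters()` and `.to(device)`.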
Methods
forward(x, input_mask) – Forward pass.
- forward(x: torch.Tensor, input_mask: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x (torch.Tensor) – Input data.
input_mask (torch.Tensor) – Input mask.
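How `x` and `input_mask` interact is not spelled out above. The sketch below illustrates the doubly residual stacking from the N-BEATS paper (each block subtracts its backcast from the residual and adds its forecast to the running total), with the mask applied to the residuals. It assumes blocks return `(backcast, forecast)`, `input_mask` holds 1 for observed and 0 for padded positions, and both tensors have shape `(batch, input_size)`. `LinearBlock` and `nbeats_forward` are hypothetical names; this is not the library's `forward` implementation.

```python
from typing import Tuple

import torch
from torch import nn


class LinearBlock(nn.Module):
    """Hypothetical minimal block returning (backcast, forecast)."""

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.backcast = nn.Linear(input_size, input_size)
        self.forecast = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.backcast(x), self.forecast(x)


def nbeats_forward(blocks: nn.ModuleList, x: torch.Tensor, input_mask: torch.Tensor) -> torch.Tensor:
    """Sketch of masked doubly residual stacking (assumed, not the library's code)."""
    residuals = x * input_mask  # zero out padded positions before the first block
    forecast = torch.zeros(())  # scalar zero; broadcasts to the first block's forecast shape
    for block in blocks:
        backcast, block_forecast = block(residuals)
        # Remove what this block explained; keep padded positions at zero.
        residuals = (residuals - backcast) * input_mask
        # The final forecast is the sum of all block forecasts.
        forecast = forecast + block_forecast
    return forecast


torch.manual_seed(0)
input_size, output_size = 12, 4
blocks = nn.ModuleList([LinearBlock(input_size, output_size) for _ in range(3)])
x = torch.randn(2, input_size)
input_mask = torch.ones(2, input_size)
input_mask[1, :5] = 0.0  # second series: first 5 points are padding
forecast = nbeats_forward(blocks, x, input_mask)
print(forecast.shape)  # torch.Size([2, 4])
```

Because padded positions are zeroed before every block sees the residual, changing the padded values of `x` leaves the forecast unchanged under these assumptions.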
- Returns
Forecast tensor.
- Return type