NBeatsInterpretableNet
- class NBeatsInterpretableNet(input_size: int, output_size: int, loss: torch.nn.modules.module.Module, trend_blocks: int, trend_layers: int, trend_layer_size: int, degree_of_polynomial: int, seasonality_blocks: int, seasonality_layers: int, seasonality_layer_size: int, num_of_harmonics: int, lr: float, optimizer_params: Optional[Dict[str, Any]] = None)
Bases: etna.models.nn.nbeats.nets.NBeatsBaseNet
Interpretable N-BEATS model.
Initialize N-BEATS model.
- Parameters
input_size (int) – Input data size.
output_size (int) – Forecast size.
loss (torch.nn.Module) – Optimisation objective. The loss function should accept three arguments: y_true, y_pred and mask. The last parameter is a binary mask that denotes which points are valid forecasts.
trend_blocks (int) – Number of trend blocks.
trend_layers (int) – Number of inner layers in each trend block.
trend_layer_size (int) – Inner layer size in trend blocks.
degree_of_polynomial (int) – Polynomial degree for trend modeling.
seasonality_blocks (int) – Number of seasonality blocks.
seasonality_layers (int) – Number of inner layers in each seasonality block.
seasonality_layer_size (int) – Inner layer size in seasonality blocks.
num_of_harmonics (int) – Number of harmonics for seasonality estimation.
lr (float) – Optimizer learning rate.
optimizer_params (Optional[Dict[str, Any]]) – Additional parameters for the optimizer.
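The loss contract described above (three arguments: y_true, y_pred and mask) can be illustrated with a small custom module. This is a minimal sketch, not a loss shipped with the library: the class name MaskedMSELoss is hypothetical, and the choice of mean squared error averaged over valid points is an assumption made only for illustration.

```python
import torch
from torch import nn


class MaskedMSELoss(nn.Module):
    """Hypothetical example: mean squared error over valid points only."""

    def forward(
        self, y_true: torch.Tensor, y_pred: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        # Cast the binary mask to the prediction dtype so it can be multiplied.
        mask = mask.to(y_pred.dtype)
        # Zero out the error at points that are not valid forecasts.
        squared_error = (y_pred - y_true) ** 2 * mask
        # Average over valid points; clamp guards against an all-zero mask.
        return squared_error.sum() / mask.sum().clamp(min=1.0)


# Tiny check: the third point is masked out, so only errors 0 and 4 count.
loss_fn = MaskedMSELoss()
y_true = torch.tensor([[1.0, 2.0, 3.0]])
y_pred = torch.tensor([[1.0, 4.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
loss_value = loss_fn(y_true, y_pred, mask)
```

An instance of such a module could then be passed as the loss argument of the constructor, provided it follows the same argument order.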
Attributes