NBeatsMSE
- class NBeatsMSE
Bases: torch.nn.modules.module.Module
MSE with mask.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Methods
forward(y_true, y_pred, mask): Compute metric.
- forward(y_true: torch.Tensor, y_pred: torch.Tensor, mask: torch.Tensor) → torch.Tensor
Compute metric.
- Parameters
y_true (torch.Tensor) – True target.
y_pred (torch.Tensor) – Predicted target.
mask (torch.Tensor) – Binary mask that denotes which points are valid forecasts.
- Returns
Metric value.
- Return type
torch.Tensor
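The idea behind a masked MSE can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not ETNA's implementation: the class name `MaskedMSE` is hypothetical, and it assumes the squared errors are averaged over valid points only (where `mask` is 1). `NBeatsMSE` itself may normalize differently, for example by averaging over all elements after masking.

```python
import torch
from torch import nn


class MaskedMSE(nn.Module):
    """Hypothetical stand-in for NBeatsMSE: MSE restricted to points where mask == 1."""

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Cast the binary mask to the prediction dtype so it can weight the errors.
        mask = mask.to(y_pred.dtype)
        # Squared errors at invalid points are zeroed out by the mask.
        squared_error = (y_pred - y_true) ** 2 * mask
        # Assumption: average over valid points only; clamp avoids dividing by zero
        # when every point is masked out.
        return squared_error.sum() / mask.sum().clamp(min=1.0)


y_true = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y_pred = torch.tensor([[1.5, 2.0, 0.0, 10.0]])
mask = torch.tensor([[1, 1, 1, 0]])  # the last point is not a valid forecast

loss = MaskedMSE()(y_true, y_pred, mask)
print(loss.item())  # (0.25 + 0 + 9) / 3 valid points, about 3.0833
```

Dividing by the number of valid points keeps the loss scale independent of how many padded points a batch contains; averaging over all elements instead would shrink the loss as more points are masked out.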