Source code for etna.transforms.decomposition.change_points_based.per_interval_models.constant

from typing import Optional

import numpy as np

from etna.transforms.decomposition.change_points_based.per_interval_models.base import PerIntervalModel


class ConstantPerIntervalModel(PerIntervalModel):
    """Per-interval model that always predicts one stored constant.

    The constant is supplied through the ``value`` keyword argument of
    :meth:`fit`; the features and target passed to ``fit`` are not read.
    """

    def __init__(self):
        """Create the model with no constant stored yet."""
        # Stays ``None`` until ``fit`` receives a ``value`` keyword.
        self.value: Optional[float] = None

    def fit(self, features: np.ndarray, target: np.ndarray, *args, **kwargs) -> "ConstantPerIntervalModel":
        """Store the constant that ``predict`` will return.

        Parameters
        ----------
        features:
            features of the interval; not used by this model
        target:
            target of the interval; not used by this model
        **kwargs:
            the ``value`` entry, if present, becomes the stored constant;
            when it is absent the stored constant is set to ``None``

        Returns
        -------
        self:
            the same model instance, with ``value`` updated
        """
        constant = kwargs.get("value")
        self.value = constant
        return self

    def predict(self, features: np.ndarray, *args, **kwargs) -> np.ndarray:
        """Build an array filled with the stored constant.

        Parameters
        ----------
        features:
            features to make prediction for; only the size of the first
            axis is used

        Returns
        -------
        prediction:
            one-dimensional array with ``features.shape[0]`` elements, each
            equal to ``self.value``
        """
        n_points = features.shape[0]
        return np.full(shape=(n_points,), fill_value=self.value)