import pathlib
import tempfile
import zipfile
from copy import deepcopy
from typing import List
from typing import Optional
import pandas as pd
from typing_extensions import Self
from etna.core import SaveMixin
from etna.core import load
from etna.datasets import TSDataset
from etna.loggers import tslogger
from etna.pipeline.base import BasePipeline
class EnsembleMixin:
    """Base mixin for the ensembles."""

    @staticmethod
    def _validate_pipeline_number(pipelines: List[BasePipeline]):
        """Check that a valid number of pipelines is given.

        Raises
        ------
        ValueError:
            If fewer than two pipelines are passed.
        """
        if not len(pipelines) >= 2:
            raise ValueError("At least two pipelines are expected.")

    @staticmethod
    def _get_horizon(pipelines: List[BasePipeline]) -> int:
        """Get ensemble's horizon.

        Raises
        ------
        ValueError:
            If the pipelines disagree on their horizon.
        """
        unique_horizons = {candidate.horizon for candidate in pipelines}
        if len(unique_horizons) > 1:
            raise ValueError("All the pipelines should have the same horizon.")
        return unique_horizons.pop()

    @staticmethod
    def _fit_pipeline(pipeline: BasePipeline, ts: TSDataset) -> BasePipeline:
        """Fit the given pipeline on ``ts`` and return it (same object, fitted in place)."""
        tslogger.log(msg=f"Start fitting {pipeline}.")
        pipeline.fit(ts=ts)
        tslogger.log(msg=f"Pipeline {pipeline} is fitted.")
        return pipeline

    @staticmethod
    def _forecast_pipeline(pipeline: BasePipeline, ts: TSDataset) -> TSDataset:
        """Make a forecast with the given pipeline and return the resulting dataset."""
        tslogger.log(msg=f"Start forecasting with {pipeline}.")
        result = pipeline.forecast(ts=ts)
        tslogger.log(msg=f"Forecast is done with {pipeline}.")
        return result

    @staticmethod
    def _predict_pipeline(
        ts: TSDataset,
        pipeline: BasePipeline,
        start_timestamp: Optional[pd.Timestamp],
        end_timestamp: Optional[pd.Timestamp],
    ) -> TSDataset:
        """Make an in-sample prediction on ``[start_timestamp, end_timestamp]`` with the given pipeline."""
        tslogger.log(msg=f"Start prediction with {pipeline}.")
        result = pipeline.predict(ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp)
        tslogger.log(msg=f"Prediction is done with {pipeline}.")
        return result
class SaveEnsembleMixin(SaveMixin):
    """Implementation of ``AbstractSaveable`` abstract class for ensemble pipelines.

    It saves object to the zip archive with 3 entities:

    * metadata.json: contains library version and class name.

    * object.pkl: pickled without pipelines and ts.

    * pipelines: folder with saved pipelines.
    """

    # Attributes expected on the concrete ensemble; they are detached during
    # pickling in ``save`` and restored/reloaded afterwards.
    pipelines: List[BasePipeline]
    ts: Optional[TSDataset]

    def save(self, path: pathlib.Path):
        """Save the object.

        Parameters
        ----------
        path:
            Path to save object to.
        """
        pipelines = self.pipelines
        ts = self.ts
        try:
            # extract attributes we can't easily save
            delattr(self, "pipelines")
            delattr(self, "ts")
            # save the remaining part
            super().save(path=path)
        finally:
            # restore the attributes even if pickling failed
            self.pipelines = pipelines
            self.ts = ts

        with zipfile.ZipFile(path, "a") as archive:
            with tempfile.TemporaryDirectory() as _temp_dir:
                temp_dir = pathlib.Path(_temp_dir)
                # save pipelines separately, each as its own zip inside "pipelines/"
                pipelines_dir = temp_dir / "pipelines"
                pipelines_dir.mkdir()
                # zero-pad the names so that lexicographic ``sorted`` in ``load``
                # restores pipelines in their original order
                num_digits = 8
                for i, pipeline in enumerate(pipelines):
                    save_name = f"{i:0{num_digits}d}.zip"
                    pipeline_save_path = pipelines_dir / save_name
                    pipeline.save(pipeline_save_path)
                    archive.write(pipeline_save_path, f"pipelines/{save_name}")

    @classmethod
    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
        """Load an object.

        Parameters
        ----------
        path:
            Path to load object from.
        ts:
            TSDataset to set into loaded pipeline.

        Returns
        -------
        :
            Loaded object.
        """
        obj = super().load(path=path)
        obj.ts = deepcopy(ts)

        with zipfile.ZipFile(path, "r") as archive:
            with tempfile.TemporaryDirectory() as _temp_dir:
                temp_dir = pathlib.Path(_temp_dir)
                archive.extractall(temp_dir)

                # load pipelines; sorted order matches the zero-padded save names
                pipelines_dir = temp_dir / "pipelines"
                pipelines = []
                # distinct loop name: the original shadowed the ``path`` parameter here
                for pipeline_path in sorted(pipelines_dir.iterdir()):
                    pipelines.append(load(pipeline_path, ts=ts))
                obj.pipelines = pipelines

        return obj