import base64
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from uuid import uuid4
import pandas as pd
from etna import SETTINGS
from etna.loggers.base import BaseLogger
if TYPE_CHECKING:
from pytorch_lightning.loggers import WandbLogger as PLWandbLogger
from etna.datasets import TSDataset
if SETTINGS.wandb_required:
import wandb
class WandbLogger(BaseLogger):
    """Weights&Biases logger."""

    def __init__(
        self,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        project: Optional[str] = None,
        job_type: Optional[str] = None,
        group: Optional[str] = None,
        tags: Optional[List[str]] = None,
        plot: bool = True,
        table: bool = True,
        name_prefix: str = "",
        config: Optional[Dict[str, Any]] = None,
        log_model: bool = False,
    ):
        """
        Create instance of WandbLogger.

        Parameters
        ----------
        name:
            Wandb run name. If not given, a name is generated as ``name_prefix``
            followed by 8 random url-safe characters.
        entity:
            An entity is a username or team name where you're sending runs.
        project:
            The name of the project where you're sending the new run
        job_type:
            Specify the type of run, which is useful when you're grouping runs together
            into larger experiments using group.
        group:
            Specify a group to organize individual runs into a larger experiment.
        tags:
            A list of strings, which will populate the list of tags on this run in the UI.
        plot:
            Indicator for making and sending plots.
        table:
            Indicator for making and sending tables.
        name_prefix:
            Prefix for the auto-generated run name; it is not applied when ``name`` is given.
        config:
            This sets `wandb.config`, a dictionary-like object for saving inputs to your job,
            like hyperparameters for a model or settings for a data preprocessing job.
        log_model:
            Value passed as ``log_model`` to the pytorch-lightning ``WandbLogger``
            created by the ``pl_logger`` property.
        """
        super().__init__()
        # The random suffix is derived from a UUID: base64 padding is stripped and
        # only the first 8 characters are kept.
        self.name = (
            name_prefix + base64.urlsafe_b64encode(uuid4().bytes).decode("utf8").rstrip("=\n")[:8]
            if name is None
            else name
        )
        self.project = project
        self.entity = entity
        self.group = group
        self.config = config
        # The wandb run is created lazily, see the ``experiment`` property.
        self._experiment = None
        self._pl_logger: Optional["PLWandbLogger"] = None
        self.job_type = job_type
        self.tags = tags
        self.plot = plot
        self.table = table
        self.name_prefix = name_prefix
        self.log_model = log_model

    def log(self, msg: Union[str, Dict[str, Any]], **kwargs):
        """
        Log any event.

        e.g. "Fitted segment segment_name" to stderr output.

        Parameters
        ----------
        msg:
            Message or dict to log
        kwargs:
            Parameters for changing additional info in log message

        Notes
        -----
        We log dictionary to wandb only; any other message is ignored by this logger.
        """
        if isinstance(msg, dict):
            self.experiment.log(msg)

    def log_backtest_metrics(
        self, ts: "TSDataset", metrics_df: pd.DataFrame, forecast_df: pd.DataFrame, fold_info_df: pd.DataFrame
    ):
        """
        Write metrics to logger.

        Parameters
        ----------
        ts:
            TSDataset with backtest data
        metrics_df:
            Dataframe produced with :py:meth:`etna.pipeline.Pipeline._get_backtest_metrics`
        forecast_df:
            Forecast from backtest
        fold_info_df:
            Fold information from backtest
        """
        # Imported here rather than at module level, as in the original code.
        from etna.analysis import plot_backtest_interactive
        from etna.datasets import TSDataset
        from etna.metrics.utils import aggregate_metrics_df

        summary: Dict[str, Any] = dict()
        if self.table:
            summary["metrics"] = wandb.Table(data=metrics_df)
            summary["forecast"] = wandb.Table(data=TSDataset.to_flatten(forecast_df))
            summary["fold_info"] = wandb.Table(data=fold_info_df)

        if self.plot:
            fig = plot_backtest_interactive(forecast_df, ts, history_len=100)
            summary["backtest"] = fig

        # Aggregated metrics are always logged, regardless of the table/plot flags.
        metrics_dict = aggregate_metrics_df(metrics_df)
        summary.update(metrics_dict)
        self.experiment.log(summary)

    def log_backtest_run(self, metrics: pd.DataFrame, forecast: pd.DataFrame, test: pd.DataFrame):
        """
        Backtest metrics from one fold to logger.

        Parameters
        ----------
        metrics:
            Dataframe with metrics from backtest fold
        forecast:
            Dataframe with forecast
        test:
            Dataframe with ground truth
        """
        from etna.datasets import TSDataset
        from etna.metrics.utils import aggregate_metrics_df

        # Move the index into a regular column and name that column "segment".
        columns_name = list(metrics.columns)
        metrics = metrics.reset_index()
        metrics.columns = ["segment"] + columns_name

        summary: Dict[str, Any] = dict()
        if self.table:
            summary["metrics"] = wandb.Table(data=metrics)
            summary["forecast"] = wandb.Table(data=TSDataset.to_flatten(forecast))
            summary["test"] = wandb.Table(data=TSDataset.to_flatten(test))

        metrics_dict = aggregate_metrics_df(metrics)
        summary.update(metrics_dict)
        self.experiment.log(summary)

    def start_experiment(self, job_type: Optional[str] = None, group: Optional[str] = None, *args, **kwargs):
        """Start experiment.

        Complete logger initialization or reinitialize it before the next experiment with the same name.

        Parameters
        ----------
        job_type:
            Specify the type of run, which is useful when you're grouping runs together
            into larger experiments using group.
        group:
            Specify a group to organize individual runs into a larger experiment.

        Notes
        -----
        The given values replace the ones stored on the logger, including when they are ``None``.
        """
        self.job_type = job_type
        self.group = group
        self.reinit_experiment()

    def reinit_experiment(self):
        """Reinit experiment by calling ``wandb.init`` with the parameters stored on the logger."""
        self._experiment = wandb.init(
            name=self.name,
            project=self.project,
            entity=self.entity,
            group=self.group,
            config=self.config,
            reinit=True,
            tags=self.tags,
            job_type=self.job_type,
            settings=wandb.Settings(start_method="thread"),
        )

    def finish_experiment(self):
        """Finish experiment.

        Does nothing if no experiment has been started yet.
        """
        # ``_experiment`` stays ``None`` until the first ``reinit_experiment`` call;
        # without this guard finishing an unstarted logger raised AttributeError.
        if self._experiment is not None:
            self._experiment.finish()

    @property
    def pl_logger(self):
        """Pytorch lightning loggers.

        A new pytorch-lightning ``WandbLogger`` bound to the current experiment
        is created on every access.
        """
        from pytorch_lightning.loggers import WandbLogger as PLWandbLogger

        self._pl_logger = PLWandbLogger(experiment=self.experiment, log_model=self.log_model)
        return self._pl_logger

    @property
    def experiment(self):
        """Init experiment on first access and return it."""
        if self._experiment is None:
            self.reinit_experiment()
        return self._experiment