diff --git a/src/utils/callbacks.py b/src/utils/callbacks.py index 61368f9d5ff53063e292f4c65449701e36d0f333..5c6a7cd27ce1c40918e6af74b0f7addedc7bec28 100644 --- a/src/utils/callbacks.py +++ b/src/utils/callbacks.py @@ -33,7 +33,7 @@ class WaVoCallback(Callback): """ # pylint: disable-next=dangerous-default-value - def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape'],ensemble=False): + def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape','conf50','conf10','conf05'],ensemble=False): super().__init__() self.chosen_metrics = mt.get_metric_dict(chosen_metrics) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index ba828908ca08992164d90e7da54da7b61d5021dd..bb7ca90ec185a86dce32cdeb1731e3bee438c96e 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -6,6 +6,30 @@ import torch from torch import Tensor +def conformal_series(y_true: Tensor,y_pred:Tensor,alpha:float = 0.95) -> Tensor: + """ + Calculate conformal prediction scores for multiple sets of predictions. + + Parameters: + ----------- + y_true : torch.Tensor + True target values + y_pred : torch.Tensor + Predicted values + alpha : float, optional (default=0.05) + 1 - Desired confidence level (between 0 and 1) + Returns: + -------- + conformity_scores : torch.Tensor + The calculated conformity scores for each of the 48 sets + """ + + confidence = 1 - alpha + residuals = torch.abs(y_true - y_pred) + conformity_scores = torch.quantile(residuals, confidence, dim=0) + return conformity_scores + + def nse_series(y_true: Tensor, y_pred: Tensor) -> Tensor: """ Calculates NSE for all prediction steps, @@ -187,7 +211,11 @@ def get_metric_dict(metric_name_list) -> dict: 'p20': p20_series, 'rmse': rmse_series, 'r2': r_squared_series, - 'wape': wape_series} + 'wape': wape_series, + 'conf50': lambda x, y: conformal_series(x, y, 0.5), + 'conf10': lambda x, y: conformal_series(x, y, 0.1), + 'conf05': lambda x, y: conformal_series(x, y, 0.05) + } metrics = {} for el in metric_name_list: