From e3bf0ac3b301165f865ad9d5184b2e3c936b5315 Mon Sep 17 00:00:00 2001 From: Michel Spils <msp@informatik.uni-kiel.de> Date: Thu, 12 Dec 2024 16:48:55 +0000 Subject: [PATCH] conformal prediction callback --- src/utils/callbacks.py | 2 +- src/utils/metrics.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/utils/callbacks.py b/src/utils/callbacks.py index 61368f9..5c6a7cd 100644 --- a/src/utils/callbacks.py +++ b/src/utils/callbacks.py @@ -33,7 +33,7 @@ class WaVoCallback(Callback): """ # pylint: disable-next=dangerous-default-value - def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape'],ensemble=False): + def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape','conf50','conf10','conf05'],ensemble=False): super().__init__() self.chosen_metrics = mt.get_metric_dict(chosen_metrics) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index ba82890..bb7ca90 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -6,6 +6,30 @@ import torch from torch import Tensor +def conformal_series(y_true: Tensor,y_pred:Tensor,alpha:float = 0.95) -> Tensor: + """ + Calculate conformal prediction scores for multiple sets of predictions. + + Parameters: + ----------- + y_true : torch.Tensor + True target values + y_pred : torch.Tensor + Predicted values + alpha : float, optional (default=0.05) + 1 - Desired confidence level (between 0 and 1) + Returns: + -------- + conformity_scores : torch.Tensor + The calculated conformity scores for each of the 48 sets + """ + + confidence = 1 - alpha + residuals = torch.abs(y_true - y_pred) + conformity_scores = torch.quantile(residuals, confidence, dim=0) + return conformity_scores + + def nse_series(y_true: Tensor, y_pred: Tensor) -> Tensor: """ Calculates NSE for all prediction steps, @@ -187,7 +211,11 @@ def get_metric_dict(metric_name_list) -> dict: 'p20': p20_series, 'rmse': rmse_series, 'r2': r_squared_series, - 'wape': wape_series} + 'wape': wape_series, + 'conf50': lambda x, y: conformal_series(x, y, 0.5), + 'conf10': lambda x, y: conformal_series(x, y, 0.1), + 'conf05': lambda x, y: conformal_series(x, y, 0.05) + } metrics = {} for el in metric_name_list: -- GitLab