From e3bf0ac3b301165f865ad9d5184b2e3c936b5315 Mon Sep 17 00:00:00 2001
From: Michel Spils <msp@informatik.uni-kiel.de>
Date: Thu, 12 Dec 2024 16:48:55 +0000
Subject: [PATCH] conformal prediction callback

---
 src/utils/callbacks.py |  2 +-
 src/utils/metrics.py   | 30 +++++++++++++++++++++++++++++-
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/utils/callbacks.py b/src/utils/callbacks.py
index 61368f9..5c6a7cd 100644
--- a/src/utils/callbacks.py
+++ b/src/utils/callbacks.py
@@ -33,7 +33,7 @@ class WaVoCallback(Callback):
     """
 
     # pylint: disable-next=dangerous-default-value
-    def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape'],ensemble=False):
+    def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape','conf50','conf10','conf05'],ensemble=False):
 
         super().__init__()
         self.chosen_metrics = mt.get_metric_dict(chosen_metrics)
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index ba82890..bb7ca90 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -6,6 +6,30 @@ import torch
 from torch import Tensor
 
 
+def conformal_series(y_true: Tensor,y_pred:Tensor,alpha:float = 0.95) -> Tensor:
+    """
+    Calculate conformal prediction scores for multiple sets of predictions.
+    
+    Parameters:
+    -----------
+    y_true : torch.Tensor
+        True target values
+    y_pred : torch.Tensor
+        Predicted values
+    alpha : float, optional (default=0.05)
+        1 - Desired confidence level (between 0 and 1)
+    Returns:
+    --------
+    conformity_scores : torch.Tensor
+        The calculated conformity scores for each of the 48 sets
+    """
+    
+    confidence = 1 - alpha
+    residuals = torch.abs(y_true - y_pred)
+    conformity_scores = torch.quantile(residuals, confidence, dim=0)
+    return conformity_scores
+
+
 def nse_series(y_true: Tensor, y_pred: Tensor) -> Tensor:
     """
     Calculates NSE for all prediction steps,
@@ -187,7 +211,11 @@ def get_metric_dict(metric_name_list) -> dict:
                       'p20': p20_series,
                       'rmse': rmse_series,
                       'r2': r_squared_series,
-                      'wape': wape_series}
+                      'wape': wape_series,
+                      'conf50': lambda x, y: conformal_series(x, y, 0.5),
+                        'conf10': lambda x, y: conformal_series(x, y, 0.1),
+                        'conf05': lambda x, y: conformal_series(x, y, 0.05)
+                      }
 
     metrics = {}
     for el in metric_name_list:
-- 
GitLab