Skip to content
Snippets Groups Projects
Commit e3bf0ac3 authored by Michel Spils's avatar Michel Spils
Browse files

conformal prediction callback

parent b47e4d15
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,7 @@ class WaVoCallback(Callback):
"""
# pylint: disable-next=dangerous-default-value
def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape'],ensemble=False):
def __init__(self, chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape','conf50','conf10','conf05'],ensemble=False):
super().__init__()
self.chosen_metrics = mt.get_metric_dict(chosen_metrics)
......
......@@ -6,6 +6,30 @@ import torch
from torch import Tensor
def conformal_series(y_true: Tensor,y_pred:Tensor,alpha:float = 0.95) -> Tensor:
"""
Calculate conformal prediction scores for multiple sets of predictions.
Parameters:
-----------
y_true : torch.Tensor
True target values
y_pred : torch.Tensor
Predicted values
alpha : float, optional (default=0.05)
1 - Desired confidence level (between 0 and 1)
Returns:
--------
conformity_scores : torch.Tensor
The calculated conformity scores for each of the 48 sets
"""
confidence = 1 - alpha
residuals = torch.abs(y_true - y_pred)
conformity_scores = torch.quantile(residuals, confidence, dim=0)
return conformity_scores
def nse_series(y_true: Tensor, y_pred: Tensor) -> Tensor:
"""
Calculates NSE for all prediction steps,
......@@ -187,7 +211,11 @@ def get_metric_dict(metric_name_list) -> dict:
'p20': p20_series,
'rmse': rmse_series,
'r2': r_squared_series,
'wape': wape_series}
'wape': wape_series,
'conf50': lambda x, y: conformal_series(x, y, 0.5),
'conf10': lambda x, y: conformal_series(x, y, 0.1),
'conf05': lambda x, y: conformal_series(x, y, 0.05)
}
metrics = {}
for el in metric_name_list:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment