diff --git a/src/arcs_uncertain_direct.py b/src/arcs_uncertain_direct.py new file mode 100644 index 0000000000000000000000000000000000000000..91b87a3462bcaf07507188bc9387f78a1d7f03e3 --- /dev/null +++ b/src/arcs_uncertain_direct.py @@ -0,0 +1,508 @@ +from functools import reduce +import lightning.pytorch as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from sklearn.preprocessing import StandardScaler +import scipy.stats as stats +import optuna +from pathlib import Path +from typing import List, Dict, Tuple +import argparse +import logging +import sys +from pathlib import Path +import utils.utility as ut +import utils.metrics as mt +from tqdm import tqdm + +from data_tools.data_module import WaVoDataModule + +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger + +# Configuration constants +IN_SIZE = 144 +OUT_SIZE = 48 +BATCH_SIZE = 512 +DIFFERENCING = 1 +FLOOD_PERCENTILE = 0.95 +TRAIN_SHARE = 0.7 +VAL_SHARE = 0.15 +#MONITOR = "hp/val_loss" +MONITOR = "val_loss" + +max_epochs = 100 + +# Set up device configuration +if torch.cuda.is_available(): + accelerator = "cuda" + devices = [0] # Or use device selection logic + torch.set_float32_matmul_precision("high") +else: + accelerator = "cpu" + devices = "auto" + +storage_base = "sqlite:///../../../../data-project/KIWaVo/models/optuna/" + + +if ut.debugger_is_active(): + default_storage_name = "debug.db" + n_trials = 2 + max_epochs = 100 + default_max_trials = 1 +else: + default_storage_name = "default.db" + + n_trials = 100 + max_epochs = 2000 + default_max_trials = None + +class CoverageMetricsCallback(pl.callbacks.Callback): + """Calculate coverage metrics for different confidence levels at the end of training.""" + + def __init__(self, confidence_levels: List[float] = [0.95, 0.90, 0.50]): + super().__init__() + self.confidence_levels = confidence_levels + self.z_scores = { + conf: stats.norm.ppf(1 - (1 - conf) / 2) + for conf in confidence_levels + } + chosen_metrics=['mse', 'mae', 'kge', 'rmse', 'r2', 'nse', 'p10', 'p20','wape','conf50','conf10','conf05'] + self.chosen_metrics = mt.get_metric_dict(chosen_metrics) + + + def calculate_coverage(self, trainer, pl_module, dataloader, split_name: str): + """Calculate coverage for a specific data split.""" + #dataloader.dataset.X = dataloader.dataset.X.to(pl_module.device) + #dataloader.dataset.y = dataloader.dataset.y.to(pl_module.device) + pl_module.eval() + + predictions = [] + uncertainties = [] + true_values = [] + initial_levels = [] + # Collect predictions + #for batch in dataloader: + with torch.no_grad(): + for batch in tqdm(dataloader,desc=split_name,total=len(dataloader)): + x, y = batch + mean, std = pl_module(x.to(pl_module.device)) + predictions.append(mean.cpu()) + uncertainties.append(std.cpu()) + true_values.append(y) + initial_levels.append(x[:, -1, pl_module.gauge_idx]) + + predictions = torch.cat(predictions) + uncertainties = torch.cat(uncertainties) + true_values = torch.cat(true_values) + initial_levels = torch.cat(initial_levels) + + mean_gauge = trainer.datamodule.scaler.mean_[trainer.datamodule.gauge_idx] + std_gauge = trainer.datamodule.scaler.scale_[trainer.datamodule.gauge_idx] + initial_levels = initial_levels * std_gauge+ mean_gauge + + + + # Denormalize predictions and uncertainties + mean = trainer.datamodule.scaler.mean_[trainer.datamodule.target_idx] + std = trainer.datamodule.scaler.scale_[trainer.datamodule.target_idx] + + denorm_preds, 
denorm_stds = denormalize_predictions( + predictions, uncertainties, mean, std + ) + + # Calculate cumulative predictions and uncertainties + cumulative_preds = initial_levels.unsqueeze(1) + torch.cumsum(denorm_preds, dim=1) + cumulative_stds_naiv = torch.sqrt(torch.cumsum(denorm_stds ** 2, dim=1)) + + + # Covariance matrix and cumulative uncertainties taking covariance into account + preds_mean = torch.mean(denorm_preds, dim=0, keepdim=True) + deviations = denorm_preds - preds_mean + covariance_matrix = (deviations.T @ deviations) / deviations.shape[0] + #propagated_std_naive_cov = [covariance_matrix[:i+1,:i+1].sum().sqrt() for i in range(OUT_SIZE)] + cumulative_std_cov = propagate_std_per_timeseries(denorm_stds, covariance_matrix) + + # Calculate true cumulative values + true_cumulative = initial_levels.unsqueeze(1) + torch.cumsum( + true_values * std + mean, dim=1 + ) + flood_mask = torch.any(torch.greater(true_cumulative, trainer.datamodule.threshold), axis=1) + + # Calculate coverage for each confidence level + metrics = {} + std_dict = {"naive": cumulative_stds_naiv, "cov": cumulative_std_cov} + + for m_name, m_func in self.chosen_metrics.items(): + metrics[f'{split_name}_{m_name}'] = torch.nan_to_num(m_func(true_cumulative, cumulative_preds)) + metrics[f'{split_name}_{m_name}_flood'] = torch.nan_to_num(m_func(true_cumulative[flood_mask], cumulative_preds[flood_mask])) + + + for conf_level, z_score in self.z_scores.items(): + for std_type, cumulative_stds in std_dict.items(): + conf_size = z_score * cumulative_stds + lower_bound = cumulative_preds -conf_size + upper_bound = cumulative_preds + conf_size + + coverage = ((lower_bound <= true_cumulative) & (true_cumulative <= upper_bound)).float().mean(dim=0) + metrics[f'{split_name}_{std_type}_coverage_{int(conf_level*100)}'] = coverage + metrics[f'{split_name}_{std_type}_interval_{int(conf_level*100)}'] = conf_size.mean(axis=0) + + coverage_flood = ((lower_bound[flood_mask] <= true_cumulative[flood_mask]) & + (true_cumulative[flood_mask] <= upper_bound[flood_mask])).float().mean(dim=0) + metrics[f'{split_name}_{std_type}_coverage_{int(conf_level*100)}_flood'] = coverage_flood + metrics[f'{split_name}_{std_type}_interval_{int(conf_level*100)}_flood'] = conf_size[flood_mask].mean(axis=0) + + for i in range(OUT_SIZE): + pl_module.logger.log_metrics({k: v[i].item() for k, v in metrics.items()}, step=i+1) + pl_module.train() + + def on_train_start(self, trainer, pl_module): + + metric_placeholders = {s: ut.get_objective_metric( + s) for s in self.chosen_metrics} + metric_placeholders = {**metric_placeholders, ** + {f'{k}_flood': v for k, v in metric_placeholders.items()}} + metric_placeholders = [{f'{cur_set}_{k}': v for k, v in metric_placeholders.items( + )} for cur_set in ['train', 'val', 'test']] + [{'val_loss': 0, 'train_loss': 0}] + metric_placeholders = reduce( + lambda a, b: {**a, **b}, metric_placeholders) + + + metric_placeholders2 = {f"{split}_{std_type}_{metric}_{int(conf_level*100)}{is_flood}": 0.0 + for split in ['train', 'val', 'test'] + for std_type in ['naive', 'cov'] + for conf_level in self.confidence_levels + for metric in ["coverage", "interval"] + for is_flood in ["", "_flood"]} + metric_placeholders2["val_loss"] = 0.0 + metric_placeholders2["train_loss"] = 0.0 + + metric_placeholders = metric_placeholders | metric_placeholders2 + + pl_module.logger.log_hyperparams(pl_module.hparams,metric_placeholders) + + + def on_train_end(self, trainer, pl_module): + """Calculate coverage metrics for all data splits at the end of 
training."""
+        print("Calculating final coverage metrics...")
+
+        # Calculate coverage for each data split
+        #trainer.datamodule.predict_dataloader()
+        #self.calculate_coverage(trainer, pl_module, trainer.train_dataloader, 'train') #TODO broken?
+        self.calculate_coverage(trainer, pl_module, trainer.datamodule.predict_dataloader(num_workers=0), 'train') #TODO broken?
+        self.calculate_coverage(trainer, pl_module, trainer.datamodule.val_dataloader(num_workers=0), 'val')
+        self.calculate_coverage(trainer, pl_module, trainer.datamodule.test_dataloader(num_workers=0), 'test')
+
+
+class WaterLevelPredictorUncertain(pl.LightningModule):
+    def __init__(self,
+                 feature_count: int,
+                 forecast_horizon: int = 48,
+                 hidden_size_lstm: int = 128,
+                 hidden_size: int = 128,
+                 num_layers_lstm: int = 2,
+                 num_layers: int = 2,  # this matches the original model
+                 learning_rate: float = 1e-3,
+                 dropout: float = 0.4,
+                 gauge_idx: int = -1):
+        super().__init__()
+        self.save_hyperparameters()
+        self.gauge_idx = gauge_idx
+
+        self.lstm = nn.LSTM(
+            input_size=feature_count,
+            hidden_size=hidden_size_lstm,
+            num_layers=num_layers_lstm,
+            dropout=dropout if num_layers_lstm > 1 else 0,
+            batch_first=True
+        )
+
+        self.mean_predictor = nn.Sequential(
+            nn.Linear(hidden_size_lstm, hidden_size),
+            nn.Dropout(dropout),
+            nn.ReLU(),
+            nn.Linear(hidden_size, forecast_horizon)
+        )
+
+        self.uncertainty_predictor = nn.Sequential(
+            nn.Linear(hidden_size_lstm, hidden_size),
+            nn.Dropout(dropout),
+            nn.ReLU(),
+            nn.Linear(hidden_size, forecast_horizon)
+        )
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        lstm_output, _ = self.lstm(x)
+        last_hidden = lstm_output[:, -1, :]
+
+        mean_prediction = self.mean_predictor(last_hidden)
+        log_variance = self.uncertainty_predictor(last_hidden)
+        uncertainty = torch.exp(0.5 * log_variance)
+
+        return mean_prediction, uncertainty
+
+    def training_step(self, batch, batch_idx):
+        loss = self._calculate_loss(batch)
+        self.log('train_loss', loss)  # TODO
+        #val_loss = self._common_step(self,batch)
+        #self.log("hp/val_loss", val_loss, sync_dist=True)  # sync_dist needed for optuna hyperparameter optimization to work
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self._calculate_loss(batch)
+        self.log('val_loss', loss)
+        return loss
+
+    def _calculate_loss(self, batch):
+        x, y = batch
+        mean, std = self(x)
+        return gaussian_nll_loss(y, mean, std)
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+
+def propagate_std_per_timeseries(denorm_stds, covariance_matrix):
+    propagated_stds = torch.zeros_like(denorm_stds)
+    propagated_stds[:, 0] = denorm_stds[:, 0]
+
+    for t in range(1, OUT_SIZE):
+        current_covar = covariance_matrix[:t+1, :t+1]
+        mask = ~torch.eye(t+1, dtype=torch.bool, device=current_covar.device)
+        # Symmetric off-diagonal sum: this already counts every (i, j) pair twice,
+        # so it is added once below, giving Var(sum) = sum(var_i) + 2 * sum_{i<j} cov_ij.
+        covar_sum = torch.sum(current_covar * mask)
+
+        current_variances = denorm_stds[:, :t+1]**2
+        variance = torch.sum(current_variances, dim=1)
+        variance += covar_sum
+        propagated_stds[:, t] = torch.sqrt(variance)
+
+    return propagated_stds
+
+
+def gaussian_nll_loss(y_true: torch.Tensor, mean: torch.Tensor,
+                      std: torch.Tensor) -> torch.Tensor:
+    """Calculate the negative log likelihood loss for a Gaussian distribution."""
+    return (torch.log(std) + 0.5 * torch.log(2 * torch.tensor(np.pi)) +
+            0.5 * ((y_true - mean) / std)**2).mean()
+
+def denormalize_predictions(normalized_changes: torch.Tensor,
+                            normalized_stds: torch.Tensor,
+                            mean: float,
+                            std: float) -> Tuple[torch.Tensor,
torch.Tensor]: + """Denormalize predictions and their uncertainties.""" + scale = torch.tensor(std, device=normalized_changes.device) + mean = torch.tensor(mean, device=normalized_changes.device) + + denormalized_changes = normalized_changes * scale + mean + denormalized_stds = normalized_stds * abs(scale) + + return denormalized_changes, denormalized_stds + + + + + +class Objective: + """ + This class defines the objective function for hyperparameter tuning using Optuna library. + + Args: + filename (str|Path): Path to .csv file, first column should be a timeindex + level_name_org (str): Name of the column with target values + log_dir (str|Path): Path to the logging directory. + """ + + def __init__( + self, + filename, + gauge, + logdir, + monitor=MONITOR, + ): + self.filename = filename + self.level_name_org = gauge + self.log_dir = logdir + self.monitor = monitor + + def _get_callbacks(self):# -> tuple[WaVoCallback, list]: + checkpoint_callback = ModelCheckpoint(save_last=None, monitor=None, save_weights_only=True) + early_stop_callback = EarlyStopping(monitor=self.monitor, mode="min", patience=3) + coverage_callback = CoverageMetricsCallback() #TODO + + + return [checkpoint_callback, + early_stop_callback, + coverage_callback, + ] + + def _get_model_params(self,trial : optuna.Trial): + + + model_params = dict( + learning_rate = trial.suggest_float("lr", 0.001, 0.01), + hidden_size_lstm=trial.suggest_int("hidden_size_lstm", 64, 512), + num_layers_lstm=trial.suggest_int("n_layers_lstm", 1, 2), + hidden_size=trial.suggest_int("hidden_size", 64, 512), + num_layers=2,#trial.suggest_int("n_layers", 2, 4), + dropout=0.4, + ) + + return model_params + + def __call__(self, trial): + model_params = self._get_model_params(trial) + #TODO to seed or not to seed? + pl.seed_everything(42, workers=True) + data_module = WaVoDataModule( + str(self.filename), + self.level_name_org, + in_size=IN_SIZE, + out_size=OUT_SIZE, + batch_size=BATCH_SIZE, + differencing=DIFFERENCING, + percentile=FLOOD_PERCENTILE, + train=TRAIN_SHARE, + val=VAL_SHARE, + model_architecture="classic_lstm", + time_enc='fixed', + ) + + model = WaterLevelPredictorUncertain( + feature_count=data_module.feature_count, + forecast_horizon=OUT_SIZE, + gauge_idx=data_module.gauge_idx, + **model_params + ) + + + logging.info("Params: %s", trial.params) + + + + callbacks = self._get_callbacks() + logger = TensorBoardLogger(self.log_dir, default_hp_metric=False) + + trainer = pl.Trainer( + default_root_dir=self.log_dir, + logger=logger, + accelerator=accelerator, + devices=devices, + callbacks=callbacks, + max_epochs=max_epochs, + log_every_n_steps=10, + ) + trainer.fit(model, data_module) + + # save metrics to optuna + model_path = str(Path(trainer.log_dir).resolve()) + logging.info("model_path: %s", model_path) + trial.set_user_attr("model_path", model_path) + return trainer.callback_metrics[self.monitor].item() + +def parse_args() -> argparse.Namespace: + """Parse all the arguments and provides some help in the command line""" + + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description="Execute Hyperparameter optimization with optuna and torchlightning." + ) + parser.add_argument( + "filename", metavar="datafile", type=Path, help="The path to your input data." + ) + parser.add_argument( + "gauge", metavar="gaugename", type=str, help="The name of the gauge column." + ) + parser.add_argument( + "logdir", type=Path, help="set a directory for logs and model checkpoints." 
+ ) + parser.add_argument( + "trials", + metavar="trials", + type=int, + default=n_trials, + help="How many trials to run.", + ) + + parser.add_argument( + "--expname", + metavar="experiment_name", + type=str, + default="nameless", + help="The name of the experiment.", + ) + parser.add_argument( + "--storagename", + metavar="storage_name", + type=str, + default=None, + help="The database for the experiment.", + ) + + return parser.parse_args() + +def setup_logging(): + # Logging, add stream handler of stdout to show the messages + logging.basicConfig(level=logging.INFO) + logFormatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S") + consoleHandler = logging.StreamHandler(sys.stdout) + consoleHandler.setFormatter(logFormatter) + optuna.logging.get_logger("optuna").addHandler(consoleHandler) + logging.info( + "Executing %s with device %s and parameters %s ", + sys.argv[0], + devices, + sys.argv[1:], + ) + + + + +def main(): + """Start a hyperparameter optimization with optuna and torchlightning.""" + setup_logging() + parsed_args = parse_args() + file_name = parsed_args.filename + gauge = parsed_args.gauge + log_dir = parsed_args.logdir + + if not log_dir.exists(): + log_dir.mkdir(parents=True) + + + + + study_name = f"{file_name.stem} {parsed_args.expname}" # Unique identifier of the study. + storage_name = ( + f"{storage_base}{default_storage_name}" + if parsed_args.storagename is None + else f"{storage_base}{parsed_args.storagename}.db" + ) + + + objective = Objective(file_name,gauge,log_dir) + sampler = optuna.samplers.RandomSampler(seed=42) + + study = optuna.create_study( + storage=storage_name, + sampler=sampler, + study_name=study_name, + direction="minimize", + load_if_exists=True, + ) + + study.set_metric_names([MONITOR]) + + study.optimize( + objective, + n_trials=parsed_args.trials, + gc_after_trial=True, + timeout=None, + callbacks=[lambda study, trial: torch.cuda.empty_cache()], + ) + + +if __name__ == "__main__": + main() + # python main_hyper.py ../../../../data-project/KIWaVo/data/input/mergedPreetzAll.csv Preetz_pegel_cm ../../../../data-project/KIWaVo/models/lfu/preetz/ 100 --expname lstms_1 --storagename lfu diff --git a/src/data_tools/data_module.py b/src/data_tools/data_module.py index 1fab9918fea75d6d52ef6bc7a0b55ce9cb1671bf..a062a767f34da819195b6822d4c5521171e44cba 100644 --- a/src/data_tools/data_module.py +++ b/src/data_tools/data_module.py @@ -159,18 +159,18 @@ class WaVoDataModule(pl.LightningDataModule): - def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, num_workers=8,persistent_workers=True) #TODO test persistent workers + def train_dataloader(self,num_workers=8) -> DataLoader: #TODO test if this doesnt crash stuff somewhere + return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=True, num_workers=num_workers,persistent_workers=num_workers>0) #TODO test persistent workers - def val_dataloader(self) -> DataLoader: - return DataLoader(self.val_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=8,persistent_workers=True) + def val_dataloader(self,num_workers=8) -> DataLoader: + return DataLoader(self.val_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=num_workers,persistent_workers=num_workers>0) - def test_dataloader(self) -> DataLoader: - return DataLoader(self.test_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=8,persistent_workers=True) + def 
test_dataloader(self,num_workers=8) -> DataLoader: + return DataLoader(self.test_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=num_workers,persistent_workers=num_workers>0) - def predict_dataloader(self)-> DataLoader: + def predict_dataloader(self,num_workers=8)-> DataLoader: # return the trainset, but sorted - return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=8,persistent_workers=True) + return DataLoader(self.train_set, batch_size=self.hparams.batch_size, shuffle=False, num_workers=num_workers,persistent_workers=num_workers>0) def get_edge_indexes(self, start:str=None,end:str=None,mode:str='single') -> Tuple[pd.DatetimeIndex,int,int]: """Return the first value of an index for the prediction and the start/end index for the input data. diff --git a/src/main_hyper.py b/src/main_hyper.py index 015ff821af0fe3125a2c4708b725411e42e944fe..2777e32a2b921a12e95b69e729e8696dff842402 100644 --- a/src/main_hyper.py +++ b/src/main_hyper.py @@ -13,7 +13,6 @@ import optuna import torch from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger -from lightning.pytorch.tuner import Tuner import utils.utility as ut from utils.callbacks import OptunaPruningCallback, WaVoCallback @@ -40,7 +39,7 @@ if ut.debugger_is_active(): "sqlite:///../../../data-project/KIWaVo/models/optuna/debug_02.db" ) n_trials = 2 - max_epochs = 2 + max_epochs = 10 default_max_trials = 1 else: # default_storage_name = 'sqlite:///../../models_torch/optuna/optimization_01.db' @@ -78,6 +77,8 @@ else: default_batch_size = 2048 + + class Objective: """ This class defines the objective function for hyperparameter tuning using Optuna library. @@ -140,7 +141,7 @@ class Objective: def _get_model_params(self,trial : optuna.Trial): #differencing = 0 - differencing = trial.suggest_int("differencing", 1, 1) + differencing = trial.suggest_int("differencing", 0, 1) learning_rate = trial.suggest_float("lr", 0.00001, 0.01) # optimizer = trial.suggest_categorical("optimizer", ["adam","adamw"]) optimizer = trial.suggest_categorical("optimizer", ["adam"]) @@ -153,7 +154,10 @@ class Objective: #model_architecture = trial.suggest_categorical("model_architecture", ["tsmixer"]) #model_architecture = trial.suggest_categorical("model_architecture", ["chained_dense"]) #model_architecture = trial.suggest_categorical("model_architecture", ["classic_lstm","last_lstm"]) - model_architecture = trial.suggest_categorical("model_architecture", ["classic_lstm"]) + #model_architecture = trial.suggest_categorical("model_architecture", ["classic_lstm"]) + #model_architecture = trial.suggest_categorical("model_architecture", ["xlstm"]) + model_architecture = trial.suggest_categorical("model_architecture", ["lstm2"]) + #model_architecture = trial.suggest_categorical("model_architecture", ["classic_lstm","new_lstm"]) if model_architecture in ["chained_dense"]: @@ -220,6 +224,27 @@ class Objective: ), # uneven numbers only ) + #if model_architecture in ["xlstm"]: + # model_params = dict() + + if model_architecture == "lstm2": + model_params = dict( + hidden_size=trial.suggest_int("hidden_size", 32, 512), + num_layers=trial.suggest_int("n_layers", 2, 4), + dropout=trial.suggest_float("dropout", 0, 0.5), + bidirectional=trial.suggest_categorical("bidirectional", [True, False]), + use_batchnorm=trial.suggest_categorical("use_batchnorm", [True, False]), + use_layernorm=trial.suggest_categorical("use_layernorm", [True, False]), + 
activation=trial.suggest_categorical("activation", ["relu", "gelu"]), + #output_activation=trial.suggest_categorical("output_activation", ["relu", "gelu"]), + output_activation=None, + ) + if model_params["bidirectional"]: + model_params["use_residual"] =False + else: + model_params["use_residual"] =trial.suggest_categorical("use_residual", [True, False]) + + return model_architecture, differencing, learning_rate,optimizer,embed_time, model_params def __call__(self, trial): @@ -247,6 +272,10 @@ class Objective: ) # prepare model + #m = myxLSTM(target_idx= data_module.target_idx) + + + model = WaVoLightningModule( model_architecture=model_architecture, feature_count=data_module.feature_count, @@ -312,8 +341,8 @@ class Objective: for i in [23, 47]: trial.set_user_attr(f'{metric}_{i+1}', my_callback.metrics[metric][i].item()) - #return my_callback.metrics["hp/val_mae"].mean().item() - return my_callback.metrics[self.monitor].item() + return my_callback.metrics["hp/val_mae"].mean().item() + #return my_callback.metrics[self.monitor].item() def parse_args() -> argparse.Namespace: diff --git a/src/models/lightning_module.py b/src/models/lightning_module.py index be5f30f7260fc1a6b6a736853986a6b41cb448de..bbecdafca1eb1f3502406eea6c3bd42187e2b16c 100644 --- a/src/models/lightning_module.py +++ b/src/models/lightning_module.py @@ -17,9 +17,10 @@ from models.Autoformer import Model as Autoformer from models.chained_dense import ChainedModel from models.DLinear import Model as DLinear from models.Informer import Model as Informer -from models.lstms import LSTM, LSTM_LAST +from models.lstms import LSTM, LSTM_LAST, LSTMv2 from models.Transformer import Model as Transformer from models.tsmixer import TSMixer +#from models.xlstms import myxLSTM # pylint: disable=import-error import utils.utility as ut @@ -231,3 +232,7 @@ class WaVoLightningModule(pl.LightningModule): return TSMixer(feature_count=self.feature_count,in_size=self.in_size,out_size=self.out_size,target_slice=self.target_idx,**kwargs) if self.model_architecture == "chained_dense": return ChainedModel(feature_count=self.feature_count,in_size=self.in_size,out_size=self.out_size,**kwargs) + #if self.model_architecture == "xlstm": + # return myxLSTM(feature_count=self.feature_count,in_size=self.in_size,out_size=self.out_size,target_idx=self.target_idx,**kwargs) + if self.model_architecture == "lstm2": + return LSTMv2(feature_count=self.feature_count,in_size=self.in_size,out_size=self.out_size,**kwargs) diff --git a/src/models/lstms.py b/src/models/lstms.py index ba855b50362eb23009589cf0f3112b9ce8f34c08..2fe2d3dd0f237d4fc11c7ae26d00de378347420d 100644 --- a/src/models/lstms.py +++ b/src/models/lstms.py @@ -1,4 +1,6 @@ +import torch import torch.nn as nn +import torch.nn.functional as F class LSTM(nn.Module): """LSTM Network @@ -114,3 +116,194 @@ class LSTM_LAST(nn.Module): out = self.mlp(output[:,-1]) return out + + + +class LSTMv2(nn.Module): + """ + A highly adaptable LSTM model that takes multivariate input and returns univariate output. 
+
+    Features:
+    - Configurable LSTM layers
+    - Batch normalization
+    - Dropout
+    - Residual connections
+    - Layer normalization
+    - Configurable activation functions
+    - Dense output layers
+
+    Args:
+        feature_count (int): Number of input features (multivariate input size)
+        in_size (int): Length of the input sequence
+        out_size (int): Length of the output sequence
+        hidden_size (int): Size of LSTM hidden state
+        num_layers (int): Number of LSTM layers
+        dropout (float): Dropout probability
+        bidirectional (bool): Whether to use bidirectional LSTM
+        use_batchnorm (bool): Whether to use batch normalization
+        use_layernorm (bool): Whether to use layer normalization
+        use_residual (bool): Whether to use residual connections
+        activation (str): Activation function to use ('relu', 'tanh', 'leakyrelu', 'gelu')
+        output_activation (str or None): Final activation function (None for linear output)
+    """
+
+    def __init__(
+        self,
+        feature_count,
+        in_size,
+        out_size,
+        hidden_size=128,
+        num_layers=2,
+        dropout=0.2,
+        bidirectional=False,
+        use_batchnorm=True,
+        use_layernorm=False,
+        use_residual=True,
+        activation='relu',
+        output_activation=None
+    ):
+
+        #def __init__(self, feature_count, in_size=144, out_size=48, hidden_size_lstm=128, hidden_size=64,num_layers_lstm=2, dropout=0.2, num_layers=2):
+
+        super(LSTMv2, self).__init__()
+
+        self.feature_count = feature_count
+        self.in_size = in_size
+        self.out_size = out_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bidirectional = bidirectional
+        self.use_batchnorm = use_batchnorm
+        self.use_layernorm = use_layernorm
+        self.use_residual = use_residual
+        self.dropout_prob = dropout
+
+        assert not (use_residual and bidirectional)
+        # Directional factor (2 if bidirectional, else 1)
+        self.d_factor = 2 if bidirectional else 1
+
+        # LSTM layer
+        self.lstm = nn.LSTM(
+            input_size=feature_count,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            batch_first=True,
+            dropout=dropout if num_layers > 1 else 0,
+            bidirectional=bidirectional
+        )
+
+        # Batch normalization for LSTM output
+        if use_batchnorm:
+            self.bn_lstm = nn.BatchNorm1d(self.hidden_size * self.d_factor)
+
+
+        # Layer normalization for LSTM output
+        if use_layernorm:
+            self.ln_lstm = nn.LayerNorm([in_size, hidden_size * self.d_factor])
+
+        # Dropout layer
+        self.dropout = nn.Dropout(dropout)
+
+        # Dense layers for output projection
+        self.dense1 = nn.Linear(hidden_size * self.d_factor, hidden_size)
+
+        if use_batchnorm:
+            #self.bn_dense = nn.BatchNorm1d(in_size)
+            self.bn_dense = nn.BatchNorm1d(hidden_size)
+
+        if use_layernorm:
+            self.ln_dense = nn.LayerNorm([in_size, hidden_size])
+
+        # Final output layer
+        self.output_layer = nn.Linear(hidden_size, out_size)
+
+        # Set activation function
+        if activation == 'relu':
+            self.activation = F.relu
+        elif activation == 'tanh':
+            self.activation = torch.tanh
+        elif activation == 'leakyrelu':
+            self.activation = F.leaky_relu
+        elif activation == 'gelu':
+            self.activation = F.gelu
+        else:
+            self.activation = F.relu  # Default to ReLU
+
+        # Set output activation function
+        if output_activation == 'sigmoid':
+            self.output_activation = torch.sigmoid
+        elif output_activation == 'tanh':
+            self.output_activation = torch.tanh
+        elif output_activation == 'softmax':
+            self.output_activation = lambda x: F.softmax(x, dim=-1)
+        else:
+            self.output_activation = None  # Linear output
+
+    def forward(self, x):
+        """
+        Forward pass of the LSTM model.
+ + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_length, in_features) + + Returns: + torch.Tensor: Output tensor of shape (batch_size, out_size) + """ + batch_size = x.size(0) + + # Initialize hidden state and cell state + h0 = torch.zeros(self.num_layers * self.d_factor, batch_size, self.hidden_size).to(x.device) + c0 = torch.zeros(self.num_layers * self.d_factor, batch_size, self.hidden_size).to(x.device) + + # LSTM forward pass + lstm_out, _ = self.lstm(x, (h0, c0)) + + # Store original output for residual connection if needed + original_lstm_out = lstm_out + + # Apply batch normalization if specified + if self.use_batchnorm: + # Reshape for batch norm which expects (batch, channels, seq_len) + lstm_out = lstm_out.permute(0, 2, 1) + lstm_out = self.bn_lstm(lstm_out) + lstm_out = lstm_out.permute(0, 2, 1) + + # Apply layer normalization if specified + if self.use_layernorm: + lstm_out = self.ln_lstm(lstm_out) + + # Apply dropout + lstm_out = self.dropout(lstm_out) + + # First dense layer + dense_out = self.dense1(lstm_out) + dense_out = self.activation(dense_out) + + # Apply batch normalization on dense output if specified + if self.use_batchnorm: + dense_out = dense_out.permute(0, 2, 1) + dense_out = self.bn_dense(dense_out) + dense_out = dense_out.permute(0, 2, 1) + + # Apply layer normalization on dense output if specified + if self.use_layernorm: + dense_out = self.ln_dense(dense_out) + + # Apply residual connection if specified + if self.use_residual and self.hidden_size == self.hidden_size * self.d_factor: + dense_out = dense_out + original_lstm_out + + # Apply dropout again + dense_out = self.dropout(dense_out) + + # Final output layer + output = self.output_layer(dense_out) + + # Get the last time step or reshape as needed + output = output[:, -1, :] # Get the last time step of each sequence + + # Apply output activation if specified + if self.output_activation is not None: + output = self.output_activation(output) + + return output \ No newline at end of file diff --git a/src/utils/utility.py b/src/utils/utility.py index 46341c9aebbf1983497dc5576f3b8d98f7f3ffa1..90d531dcaf54f0beb0740755ef39e3ce964a6a7e 100644 --- a/src/utils/utility.py +++ b/src/utils/utility.py @@ -187,7 +187,7 @@ def debugger_is_active() -> bool: return hasattr(sys, 'gettrace') and sys.gettrace() is not None def need_classic_input(s :str) -> bool: - return s in ["classic_lstm","last_lstm","tsmixer","chained_dense","ensemble"] or s is None + return s in ["classic_lstm","last_lstm","tsmixer","chained_dense","ensemble","xlstm","lstm2"] or s is None def encode_time(df_stamp,encoding='fixed'): if encoding=='timeF':
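
Reviewer note (appended, not part of the patch): propagate_std_per_timeseries widens the cumulative prediction intervals with the usual error-propagation identity Var(sum_i X_i) = sum_i Var(X_i) + 2 * sum_{i<j} Cov(X_i, X_j); the "naive" variant keeps only the first term, so it understates interval widths whenever consecutive step errors are positively correlated. Below is a minimal, self-contained sanity check of that identity against the helper's logic. All names, shapes, and the seed are illustrative assumptions, not part of the repository.

# Illustrative sketch only: checks Var(sum) = sum(var_i) + 2 * sum_{i<j} cov_ij
import torch

def propagate_std(stds: torch.Tensor, cov: torch.Tensor) -> torch.Tensor:
    """Cumulative std per horizon step, mirroring propagate_std_per_timeseries."""
    out = torch.zeros_like(stds)
    out[:, 0] = stds[:, 0]
    for t in range(1, stds.shape[1]):
        block = cov[:t + 1, :t + 1]
        # symmetric off-diagonal sum == 2 * sum_{i<j} cov_ij
        off_diag = block.sum() - block.diagonal().sum()
        variance = (stds[:, :t + 1] ** 2).sum(dim=1) + off_diag
        out[:, t] = variance.clamp(min=0).sqrt()
    return out

torch.manual_seed(0)
n_samples, horizon = 256, 6
preds = torch.randn(n_samples, horizon)              # stand-in for denormalized step predictions
deviations = preds - preds.mean(dim=0, keepdim=True)
cov = deviations.T @ deviations / n_samples          # empirical step covariance, as in the callback

# Special case: per-sample stds taken from the covariance diagonal. Then the
# propagated variance at step t must equal the full block sum 1^T C[:t+1,:t+1] 1.
stds = cov.diagonal().sqrt().expand(n_samples, horizon)
propagated = propagate_std(stds, cov)
expected = torch.stack([cov[:t + 1, :t + 1].sum().clamp(min=0).sqrt() for t in range(horizon)])
assert torch.allclose(propagated[0], expected, atol=1e-5)

On real data the per-sample predictive stds replace the covariance diagonal, which is exactly what the callback's "cov" variant does; the comparison against the "naive" variant then shows how much of the interval width comes from correlated step errors.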