train_ensemble.py
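
    """Optuna hyperparameter search for WaVo ensemble models.

    Builds an ensemble from pre-trained base models, tunes the ensemble's
    hyperparameters with Optuna, and logs results to TensorBoard and a
    SQLite-backed Optuna storage.

    Example invocation (paths and names are illustrative):
        python train_ensemble.py data.csv models/lightning_logs/ logs/ 100 first 5 --expname my_exp
    """
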
    import argparse
    import logging
    import random
    import sys
    from pathlib import Path
    
    import lightning.pytorch as pl
    import optuna
    import torch
    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
    from lightning.pytorch.loggers import TensorBoardLogger
    
    import utils.helpers as hp
    import utils.utility as ut
    from utils.callbacks import OptunaPruningCallback, WaVoCallback
    from data_tools.data_module import WaVoDataModule
    from models.ensemble_models import WaVoLightningEnsemble, WaVoLightningAttentionEnsemble
    
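    # If CUDA is available, train on the single GPU that currently has the most free memory.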
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        accelerator = "cuda"
        torch.set_float32_matmul_precision("high")
    
        max_free = 0
        best_device = None
        for j in range(torch.cuda.device_count()):
            free, _ = torch.cuda.mem_get_info(j)
            if free > max_free:
                max_free = free
                best_device = j
        devices = [best_device]
    
    else:
        accelerator = "cpu"
        devices = "auto"
    
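    # Base path for Optuna SQLite storages; the --storagename argument is appended to it in main().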
    storage_base = "sqlite:///../../../../data-project/KIWaVo/models/optuna/"
    
    if ut.debugger_is_active():
        max_epochs = 2
        default_storage_name = "sqlite:///../../../data-project/KIWaVo/models/optuna/icaart_ensemble_debug_01.db"
    
    else:
        max_epochs = 2000
        default_storage_name = (
            "sqlite:///../../../../data-project/KIWaVo/models/optuna/icaart_ensemble_01.db"
        )
    
    
    class Objective:
        """
        Callable objective for tuning ensemble hyperparameters with the Optuna library.

        Args:
            filename (str|Path): Path to a .csv file; the first column should be a time index.
            modeldir (str|Path): Path to the directory containing the base models.
            logdir (str|Path): Path to the logging directory.
            select_strat (str): How to select the base models ("random", "first", or "hardcoded").
            model_count (int): How many base models to use.
            monitor (str, optional): Metric to monitor. Defaults to "hp/val_loss".
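
        Example (mirrors the usage in main(); paths are placeholders):
            >>> objective = Objective(Path("data.csv"), Path("models/lightning_logs/"),
            ...                       Path("logs/"), select_strat="first", model_count=5)
            >>> study = optuna.create_study(direction="minimize")
            >>> study.optimize(objective, n_trials=10)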
        """
    
        def __init__(
            self,
            filename,
            modeldir,
            logdir,
            select_strat,
            model_count,
            monitor="hp/val_loss",
            **kwargs,
        ):
            # Store the implementation-specific arguments as fields of the class.
            self.filename = filename
            self.model_dir = modeldir
            self.log_dir = logdir
            self.select_strat = select_strat
            self.model_count = model_count
            self.monitor = monitor
    
        def _get_callbacks(self, trial):
            checkpoint_callback = ModelCheckpoint(
                save_top_k=1, monitor=self.monitor, save_weights_only=True
            )
            pruning_callback = OptunaPruningCallback(trial, monitor=self.monitor)
            early_stop_callback = EarlyStopping(
                monitor=self.monitor, mode="min", patience=3
            )
            my_callback = WaVoCallback(ensemble=True)
    
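            # Return my_callback separately so __call__ can read its recorded metrics after fit().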
            return my_callback, [
                checkpoint_callback,
                pruning_callback,
                early_stop_callback,
                my_callback,
            ]
    
        def _get_model_params(self, trial):
            model_params = dict(
                hidden_size=trial.suggest_int("hidden_size", 32, 512),
                num_layers=trial.suggest_int("n_layers", 2, 4),
                dropout=0.25,
                learning_rate=trial.suggest_float("lr", 0.00001, 0.01),
                norm_func=trial.suggest_categorical("norm_func", ["softmax", "minmax"]),
            )
    
            return model_params
    
        def __call__(self, trial):
    
            model_params = self._get_model_params(trial)
            # log_dir = '../../../data-project/KIWaVo/models/ensemble_debug/'
            # model_dir = Path('../../../data-project/KIWaVo/models/icaarts_hollingstedt/lightning_logs/')
            model_list = []
            model_path_list = []
            yaml_data = None
    
            if self.select_strat == "random":
                # Note: the original code never set model_choice in this branch;
                # sampling model_count version numbers without replacement is an assumption.
                all_models = [x.name.split("_")[1] for x in self.model_dir.iterdir()]
                model_choice = random.sample(all_models, self.model_count)
            elif self.select_strat == "first":
                model_choice = list(range(self.model_count))
            elif self.select_strat == "hardcoded":
                model_choice = [37, 88, 106, 110, 116, 137, 171, 175, 181, 186]
            else:
                raise ValueError(f"Invalid selection strategy: {self.select_strat}")

            for s in model_choice:
                temp_dir = self.model_dir / f"version_{s}/"
    
                model_list.append(
                    hp.load_model_cuda(temp_dir, use_cuda=use_cuda, devices=devices)
                )
                model_path_list.append(str(temp_dir.resolve()))
    
                if yaml_data is None:
                    yaml_data = hp.load_settings_model(temp_dir)
                    # with open(temp_dir / 'hparams.yaml', 'r') as file:
                    #    yaml_data = yaml.load(file, Loader=yaml.FullLoader)
                    #    yaml_data['scaler'] = pickle.loads(yaml_data['scaler'])
    
            config = {
                # TODO: test how this works without passing the scaler etc. outside of JupyterLab
                "scaler": yaml_data["scaler"],
                # "filename": yaml_data["filename"],
                "filename": str(self.filename),
                "level_name_org": yaml_data["level_name_org"],
                "out_size": yaml_data["out_size"],
                "threshold": yaml_data["threshold"],
                "feature_count": yaml_data["feature_count"],
                "differencing": yaml_data["differencing"],
                "model_architecture": "ensemble",
            }
    
            print("ACHTUNG GGF. FALSCHES MODELL")
            ensemble_model = WaVoLightningEnsemble(model_list,model_path_list,**model_params)#TODO 'ÄNDERN!
            #ensemble_model = WaVoLightningAttentionEnsemble(
            #    model_list, model_path_list, **model_params
            #)
    
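            # Size the data module's input window to the largest among the base models.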
            config["in_size"] = ensemble_model.max_in_size
    
            data_module = WaVoDataModule(**config)
    
            logging.info("Params: %s", trial.params)
    
            my_callback, callbacks = self._get_callbacks(trial)
            logger = TensorBoardLogger(self.log_dir, default_hp_metric=False)
    
            trainer = pl.Trainer(
                default_root_dir=self.log_dir,
                gradient_clip_val=0.5,
                logger=logger,
                accelerator=accelerator,
                devices=devices,
                callbacks=callbacks,
                max_epochs=max_epochs,
                log_every_n_steps=10,
            )
            trainer.fit(ensemble_model, data_module)
    
            # save metrics to optuna
            model_path = str(Path(trainer.log_dir).resolve())
            logging.info("model_path: %s", model_path)
            trial.set_user_attr("model_path", model_path)
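            # Persist selected validation metrics at forecast steps 23 and 47 (the 24th and 48th horizons).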
            for metric in ["hp/val_nse", "hp/val_mae", "hp/val_mae_flood"]:
                for i in [23, 47]:
                    trial.set_user_attr(
                        f"{metric}_{i}", my_callback.metrics[metric][i].item()
                    )
    
            return my_callback.metrics[self.monitor].item()
    
    
    def parse_args() -> argparse.Namespace:
        """Parse all the arguments and provides some help in the command line"""
    
        parser: argparse.ArgumentParser = argparse.ArgumentParser(
            description="Execute experiments for exp_icaart."
        )
        parser.add_argument(
            "filename", metavar="datafile", type=Path, help="The path to your input data."
        )
        parser.add_argument(
            "modeldir", metavar="modeldir", type=Path, help="The path to your base models."
        )
        parser.add_argument(
            "logdir", type=Path, help="set a directory for logs and model checkpoints."
        )
        parser.add_argument(
            "trials",
            metavar="trials",
            type=int,
            help="How many trials to run.",
        )
        parser.add_argument("select_strat", choices=["random", "first","hardcoded"], help="How to select the base models.")
        parser.add_argument(
            "model_count",
            metavar="mc",
            type=int,
            help="How many base models to use.",
        )
    
        parser.add_argument(
            "--expname",
            metavar="experiment_name",
            type=str,
            default="nameless",
            help="The name of the experiment.",
        )
        parser.add_argument(
            "--storagename",
            metavar="storage_name",
            type=str,
            default=None,
            help="The database for the experiment.",
        )
    
        return parser.parse_args()
    
    
    def main():
        parsed_args = parse_args()
        if not parsed_args.logdir.exists():
            parsed_args.logdir.mkdir(parents=True)
    
        # Pruning is currently disabled; swap in the HyperbandPruner below to re-enable it.
        # pruner = optuna.pruners.HyperbandPruner(
        #     min_resource=1, max_resource="auto", reduction_factor=3, bootstrap_count=0
        # )
        pruner = optuna.pruners.NopPruner()
    
        study_name = f"{parsed_args.filename.stem} {parsed_args.expname}"  # Unique identifier of the study.
        storage_name = (
            default_storage_name
            if parsed_args.storagename is None
            else f"{storage_base}{parsed_args.storagename}.db"
        )
    
        # Logging setup: write to a file in logdir and mirror Optuna's output to stdout.
        logging.basicConfig(level=logging.INFO)
        logFormatter = logging.Formatter(
            "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        fileHandler = logging.FileHandler(
            parsed_args.logdir / "ensemble_hyper.log",
        )
        consoleHandler = logging.StreamHandler(sys.stdout)
        fileHandler.setFormatter(logFormatter)
        consoleHandler.setFormatter(logFormatter)
        logging.getLogger().addHandler(fileHandler)
        # logging.getLogger().addHandler(consoleHandler)
        optuna.logging.get_logger("optuna").addHandler(fileHandler)
        optuna.logging.get_logger("optuna").addHandler(consoleHandler)
    
        logging.info(
            "Start of this execution======================================================================"
        )
        logging.info(
            "Executing %s with device %s and parameters %s ",
            sys.argv[0],
            devices,
            sys.argv[1:],
        )
    
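        # With load_if_exists=True, repeated runs using the same study name resume from shared storage.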
        study = optuna.create_study(
            study_name=study_name,
            storage=storage_name,
            direction="minimize",
            pruner=pruner,
            load_if_exists=True,
        )
    
        study.set_metric_names(["hp/val_loss"])
        objective = Objective(**vars(parsed_args))

        study.optimize(
            objective,
            n_trials=parsed_args.trials,
            timeout=None,
            gc_after_trial=True,  # gc_after_trial belongs to study.optimize, not Objective
            callbacks=[lambda study, trial: torch.cuda.empty_cache()],
        )
    
    
    if __name__ == "__main__":
        main()