diff --git a/src/app.py b/src/app.py index 860c4aa49d77bc3a21774b51ceb6fe7c28e68a71..002888ae3e459a92d3030484a5962b9eb2d65b89 100644 --- a/src/app.py +++ b/src/app.py @@ -22,10 +22,8 @@ from dashboard.constants import NUM_RECENT_FORECASTS, NUM_RECENT_LOGS, BUFFER_TI from utils.db_tools.orm_classes import (InputForecasts, Log, Metric, Modell, ModellSensor, PegelForecasts, SensorData) -import utils.utility as ut - -debug = True #ut.debugger_is_active() +debug = False #TODO add logging to dashboard? class ForecastMonitor: @@ -411,7 +409,7 @@ class ForecastMonitor: print(f"Error getting model metrics: {str(e)}") return None, None - def get_confidence(self, model_id, alpha='10',flood=False,subset='test'): + def get_confidence(self, model_id,flood=False,subset='test'): """Get metrics for a specific model. Assumes that each model only has on set of metrics with the same name""" try: @@ -528,7 +526,7 @@ class ForecastMonitor: sensor_names = list(df_historical.columns) df_inp_fcst,ext_forecast_names = self.get_input_forecasts(sensor_names) df_forecasts = self.get_recent_forecasts(actual_model_name) - df_conf = self.get_confidence(model_id, alpha='10',flood=False,subset='test') + df_conf = self.get_confidence(model_id,flood=False,subset='test') # Format logs diff --git a/src/models/ensemble_models.py b/src/models/ensemble_models.py index 92028bcd913a54c33ac61cfbe7b22e25e569607d..55364e877122d81238758e416bceef6faf02bec6 100644 --- a/src/models/ensemble_models.py +++ b/src/models/ensemble_models.py @@ -131,15 +131,25 @@ class WaVoLightningEnsemble(pl.LightningModule): def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate): super().__init__() - self.model_list = model_list - self.max_in_size = max([model.hparams['in_size'] for model in self.model_list]) - self.out_size = model_list[0].hparams['out_size'] - self.feature_count = model_list[0].hparams['feature_count'] - - assert len(set([model.target_idx for model in 
self.model_list])) == 1, "All models in the ensemble must have the same target_idx" - self.target_idx = model_list[0].target_idx + + if model_list is None: + model_list = [] + self.model_list = model_list + self.max_in_size = 144 + self.out_size = 48 + self.feature_count = 1 + self.target_idx = 0 + self.scaler = None + else: + self.model_list = model_list + self.max_in_size = max([model.hparams['in_size'] for model in self.model_list]) + self.out_size = model_list[0].hparams['out_size'] + self.feature_count = model_list[0].hparams['feature_count'] + self.target_idx = model_list[0].target_idx + self.scaler = model_list[0].scaler + assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx" self.model_architecture = 'ensemble' - self.scaler = model_list[0].scaler + #self.weighting = WeightingModelAttention( self.weighting = WeightingModel( diff --git a/src/predict_database.py b/src/predict_database.py index 5ef0d5b27fd49f6f0a940b98ca6b69da81a90983..8ffbfd825512566abac338e60bdbce67866fed15 100644 --- a/src/predict_database.py +++ b/src/predict_database.py @@ -42,7 +42,7 @@ def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]: main_config["end"] = pd.to_datetime(main_config["end"]) if main_config.get("fake_external"): - assert main_config["range"] == True, "fake_external only works with range=True" + assert main_config["range"], "fake_external only works with range=True" if passed_args.start is not None: main_config["start"] = passed_args.start @@ -57,7 +57,13 @@ def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]: #assert all([isinstance(k,str) and isinstance(v,str) for k,v in gauge_config["external_fcst"].items()]) assert all([isinstance(c,str) for c in gauge_config["external_fcst"]]) - return main_config, gauge_configs + ensemble_configs = configs.get("ensemble_configs", []) + for ensemble_config in ensemble_configs: + #ensemble_config["model_names"] 
def handle_ensemble(self, ensemble_config: dict) -> None:
    """Build ensemble forecasts from stored member-model forecasts and insert them.

    The forecasts of every model listed in ``ensemble_config["model_names"]`` are
    read from the database for the timestamps in ``self.times``. For each
    timestamp and each member in ``self.members`` one forecast is written via
    ``self.insert_forecast``:

    * no ensemble model configured: the column-wise mean over the forecasts of
      the listed models is inserted;
    * ensemble model configured: the checkpoint is loaded, but no prediction is
      made yet and ``None`` is passed to ``insert_forecast`` as the forecast.

    Args:
        ensemble_config (dict): Configuration of the ensemble. Keys read here:
            ``"model_names"``: model names, matched against ``Modell.modellname``.
            ``"ensemble_name"``: passed on as ``'model_folder'`` of the config
            handed to ``insert_forecast``.
            ``"ensemble_model"``: optional folder of a trained ensemble model
            (settings file plus a ``checkpoints`` directory); ``None`` or a
            missing key selects the mean.

    Raises:
        IndexError: If a configured model name has no row in the model table.
    """
    # Function-scope import so this method does not rely on a module-level import.
    import logging

    created = datetime.now(timezone.utc)
    model_names = list(ensemble_config["model_names"])
    ensemble_model_folder = ensemble_config.get("ensemble_model")

    model_stmt = select(Modell).where(
        Modell.modellname.in_(bindparam("model_names", expanding=True))
    )
    df_models = pd.read_sql(
        sql=model_stmt,
        con=self.engine,
        params={"model_names": model_names},
    )

    # Resolve every configured model name to the path stored in the model table.
    model_list = []
    for model_name in model_names:
        model_files = df_models.loc[df_models["modellname"] == model_name, "modelldatei"]
        if model_files.empty:
            # Same exception type the bare ``.values[0]`` lookup raised, but
            # with a message that names the missing model.
            raise IndexError(f"Ensemble member model '{model_name}' not found in the model table")
        model_list.append(Path(model_files.iloc[0]))
    # Forecast rows are filtered by the last path component of each model path.
    pseudo_model_names = [model_path.name for model_path in model_list]

    forecast_stmt = select(PegelForecasts).where(
        PegelForecasts.tstamp.in_(bindparam("times", expanding=True)),
        PegelForecasts.model.in_(bindparam("pseudo_model_names", expanding=True)),
    )
    df_ens = pd.read_sql(
        sql=forecast_stmt,
        con=self.engine,
        index_col="tstamp",
        params={
            "times": self.times,
            "pseudo_model_names": pseudo_model_names,
        },
    )

    if df_ens.empty:
        # Without member forecasts there is nothing to combine; ``iloc[0, 0]``
        # below would otherwise fail with an opaque IndexError.
        logging.warning(
            "No member forecasts found for ensemble %s, skipping",
            ensemble_config["ensemble_name"],
        )
        return

    # NOTE(review): presumably the first column after the index holds the
    # gauge/sensor name - confirm against the PegelForecasts column order.
    fake_config = {
        "gauge": df_ens.iloc[0, 0],
        "model_folder": ensemble_config["ensemble_name"],
    }

    if ensemble_model_folder is not None:
        base_model = hp.load_model(model_list[0])

        ensemble_folder = Path(ensemble_model_folder)
        ensemble_yaml = hp.load_settings_model(ensemble_folder)

        from models.ensemble_models import WaVoLightningEnsemble

        # The first entry of the checkpoints directory is used.
        checkpoint_path = next((ensemble_folder / "checkpoints").iterdir())
        # The first member model is repeated as a stand-in for every entry of
        # the saved model_path_list when restoring the checkpoint.
        temp_model_list = len(ensemble_yaml["model_path_list"]) * [base_model]
        # NOTE(review): the loaded model is not used for prediction below;
        # forecasts for a configured ensemble model are inserted as None.
        model = WaVoLightningEnsemble.load_from_checkpoint(
            checkpoint_path, model_list=temp_model_list
        )

    for cur_time in df_ens.index.unique():
        # List indexing always yields a DataFrame; ``df_ens.loc[cur_time]``
        # returns a Series when exactly one row matches, which breaks the
        # member filter below.
        df_time_slice = df_ens.loc[[cur_time]]
        for member in self.members:
            if ensemble_model_folder is None:
                df_member = df_time_slice[df_time_slice["member"] == member]
                # NOTE(review): assumes the forecast value columns start at
                # position 4 - confirm against the table layout.
                y = torch.Tensor(df_member.iloc[:, 4:].mean())
            else:
                y = None

            self.insert_forecast(y, fake_config, cur_time, member, created)