diff --git a/src/train_db.py b/src/train_db.py index 7ad705823b59215e04a48a7c85ba56349b7e3498..8b199fe0717c8ea1d43c41275a1d34d6a65311d1 100644 --- a/src/train_db.py +++ b/src/train_db.py @@ -132,6 +132,7 @@ class DBObjective: # Save metrics to optuna model_path = str(Path(trainer.log_dir).resolve()) trial.set_user_attr("model_path", model_path) + trial.set_user_attr("external_fcst", data_module.external_fcst) for metric in ["hp/val_nse", "hp/val_mae", "hp/val_mae_flood"]: for i in [23, 47]: @@ -141,6 +142,7 @@ class DBObjective: if self.config["general"]["save_all_to_db"]: model_id = dbt.add_model(self.engine,self.config,trainer.log_dir) if model_id != -1: + dbt.add_model_sensor(self.engine,model_id,data_module.sensors,data_module.external_fcst) dbt.add_metrics(self.engine,model_id,my_callback.metrics,data_module.time_ranges) #dbt.add_metrics(self.engine,self.config,trainer.log_dir,key,value) @@ -248,11 +250,14 @@ def main(): #TODO mark the model somehow maybe? else: model_path = best_trial.user_attrs["model_path"] + external_fcst = best_trial.user_attrs["external_fcst"] hparams = hp.load_settings_model(Path(model_path)) time_ranges = {s: {a : pd.to_datetime(hparams[f"{s}_{a}"]) for a in ["start","end"]} for s in ["train","val","test"]} model_id = dbt.add_model(engine,config,model_path) if model_id != -1: + dbt.add_model_sensor(engine,model_id,hparams["scaler"].feature_names_in_,external_fcst) + reader = SummaryReader(model_path,pivot=True) df =reader.scalars loss_cols = df.columns[df.columns.str.contains("loss")] @@ -262,8 +267,6 @@ def main(): metrics[k] = torch.tensor(v) dbt.add_metrics(engine,model_id,metrics,time_ranges) - #dbt.add_metrics(self.engine,self.config,trainer.log_dir,key,value) - con.close() diff --git a/src/utils/db_tools/db_tools.py b/src/utils/db_tools/db_tools.py index 16320552bdce546306d64ba88f3b190bff74d472..c4eb34a28bcf23bc042ad9716f4ab9e5eb57fe66 100644 --- a/src/utils/db_tools/db_tools.py +++ b/src/utils/db_tools/db_tools.py @@ -1014,6 +1014,28 @@ def add_model(engine,config,model_path,in_size=144) -> int: logging.info("Added model %s to database",model_obj) return model_obj.id +def add_model_sensor(engine,model_id,sensor_names,external_fcst) -> None: + with Session(engine) as session: + + + for i,col in enumerate(sensor_names): + if col != "d1": + if col in external_fcst: + vhs_gebiet = session.scalars(select(InputForecastsMeta.vhs_gebiet).where(InputForecastsMeta.sensor_name == col)).first() + if vhs_gebiet is None: + raise LookupError(f"Sensor {col} not found in INPUT_FORECASTS_META") + else: + vhs_gebiet = None + + model_sensor = ModellSensor( + modell_id = model_id, + sensor_name = col, + vhs_gebiet = vhs_gebiet, + ix = i + ) + session.add(model_sensor) + session.commit() + def add_metrics(engine,model_id,metrics,time_ranges) -> None: metric_list = [] for key,value in metrics.items():