Skip to content
Snippets Groups Projects
Commit 2a7023be authored by Michel Spils's avatar Michel Spils
Browse files

also import modelsensor when importing while training

parent ba56fd61
No related branches found
No related tags found
No related merge requests found
...@@ -132,6 +132,7 @@ class DBObjective: ...@@ -132,6 +132,7 @@ class DBObjective:
# Save metrics to optuna # Save metrics to optuna
model_path = str(Path(trainer.log_dir).resolve()) model_path = str(Path(trainer.log_dir).resolve())
trial.set_user_attr("model_path", model_path) trial.set_user_attr("model_path", model_path)
trial.set_user_attr("external_fcst", data_module.external_fcst)
for metric in ["hp/val_nse", "hp/val_mae", "hp/val_mae_flood"]: for metric in ["hp/val_nse", "hp/val_mae", "hp/val_mae_flood"]:
for i in [23, 47]: for i in [23, 47]:
...@@ -141,6 +142,7 @@ class DBObjective: ...@@ -141,6 +142,7 @@ class DBObjective:
if self.config["general"]["save_all_to_db"]: if self.config["general"]["save_all_to_db"]:
model_id = dbt.add_model(self.engine,self.config,trainer.log_dir) model_id = dbt.add_model(self.engine,self.config,trainer.log_dir)
if model_id != -1: if model_id != -1:
dbt.add_model_sensor(self.engine,model_id,data_module.sensors,data_module.external_fcst)
dbt.add_metrics(self.engine,model_id,my_callback.metrics,data_module.time_ranges) dbt.add_metrics(self.engine,model_id,my_callback.metrics,data_module.time_ranges)
#dbt.add_metrics(self.engine,self.config,trainer.log_dir,key,value) #dbt.add_metrics(self.engine,self.config,trainer.log_dir,key,value)
...@@ -248,11 +250,14 @@ def main(): ...@@ -248,11 +250,14 @@ def main():
#TODO mark the model somehow maybe? #TODO mark the model somehow maybe?
else: else:
model_path = best_trial.user_attrs["model_path"] model_path = best_trial.user_attrs["model_path"]
external_fcst = best_trial.user_attrs["external_fcst"]
hparams = hp.load_settings_model(Path(model_path)) hparams = hp.load_settings_model(Path(model_path))
time_ranges = {s: {a : pd.to_datetime(hparams[f"{s}_{a}"]) for a in ["start","end"]} for s in ["train","val","test"]} time_ranges = {s: {a : pd.to_datetime(hparams[f"{s}_{a}"]) for a in ["start","end"]} for s in ["train","val","test"]}
model_id = dbt.add_model(engine,config,model_path) model_id = dbt.add_model(engine,config,model_path)
if model_id != -1: if model_id != -1:
dbt.add_model_sensor(engine,model_id,hparams["scaler"].feature_names_in_,external_fcst)
reader = SummaryReader(model_path,pivot=True) reader = SummaryReader(model_path,pivot=True)
df =reader.scalars df =reader.scalars
loss_cols = df.columns[df.columns.str.contains("loss")] loss_cols = df.columns[df.columns.str.contains("loss")]
...@@ -262,8 +267,6 @@ def main(): ...@@ -262,8 +267,6 @@ def main():
metrics[k] = torch.tensor(v) metrics[k] = torch.tensor(v)
dbt.add_metrics(engine,model_id,metrics,time_ranges) dbt.add_metrics(engine,model_id,metrics,time_ranges)
#dbt.add_metrics(self.engine,self.config,trainer.log_dir,key,value)
con.close() con.close()
......
...@@ -1014,6 +1014,28 @@ def add_model(engine,config,model_path,in_size=144) -> int: ...@@ -1014,6 +1014,28 @@ def add_model(engine,config,model_path,in_size=144) -> int:
logging.info("Added model %s to database",model_obj) logging.info("Added model %s to database",model_obj)
return model_obj.id return model_obj.id
def add_model_sensor(engine,model_id,sensor_names,external_fcst) -> None:
with Session(engine) as session:
for i,col in enumerate(sensor_names):
if col != "d1":
if col in external_fcst:
vhs_gebiet = session.scalars(select(InputForecastsMeta.vhs_gebiet).where(InputForecastsMeta.sensor_name == col)).first()
if vhs_gebiet is None:
raise LookupError(f"Sensor {col} not found in INPUT_FORECASTS_META")
else:
vhs_gebiet = None
model_sensor = ModellSensor(
modell_id = model_id,
sensor_name = col,
vhs_gebiet = vhs_gebiet,
ix = i
)
session.add(model_sensor)
session.commit()
def add_metrics(engine,model_id,metrics,time_ranges) -> None: def add_metrics(engine,model_id,metrics,time_ranges) -> None:
metric_list = [] metric_list = []
for key,value in metrics.items(): for key,value in metrics.items():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment