diff --git a/src/models/ensemble_models.py b/src/models/ensemble_models.py
index b245c971932b9c7985e6fd6afb5b24597ad7ce0d..f83c6845b7b80c8044732a8e017ae90213d85122 100644
--- a/src/models/ensemble_models.py
+++ b/src/models/ensemble_models.py
@@ -128,7 +128,7 @@ class WaVoLightningEnsemble(pl.LightningModule):
     """ Since the data for this is normalized with the scaler from the first model in model_list, all models should be trained on the same or at least similar data. """
     # pylint: disable-next=unused-argument
-    def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate):
+    def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate,return_base=False):
         super().__init__()
 
         self.model_list = model_list
@@ -139,7 +139,7 @@ class WaVoLightningEnsemble(pl.LightningModule):
         self.scaler = model_list[0].scaler #TODO load the scaler properly from the yaml -> pass it as a parameter, like for the normal model.
         assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx"
         self.model_architecture = 'ensemble'
-
+        self.return_base = return_base
         #self.weighting = WeightingModelAttention(
         self.weighting = WeightingModel(
@@ -154,11 +154,14 @@ class WaVoLightningEnsemble(pl.LightningModule):
         #TODO different insizes
         #TODO ModuleList?!
-        self.save_hyperparameters(ignore=['model_list','scaler'])
+        self.save_hyperparameters(ignore=['model_list','scaler','return_base'])
         self.save_hyperparameters({"scaler": pickle.dumps(self.scaler)})
 
         for model in self.model_list:
             model.freeze()
 
+    def set_return_base(self,return_base):
+        self.return_base = return_base
+
     def _common_step(self, batch, batch_idx,dataloader_idx=0):
         """ Computes the weighted average of the predictions of the models in the ensemble.
@@ -184,7 +187,8 @@ class WaVoLightningEnsemble(pl.LightningModule):
         y_hats = torch.stack(y_hat_list,axis=1)
         w = self.weighting(x)
         y_hat = torch.sum((y_hats*w),axis=1)
-
+        if self.return_base:
+            return y_hat, y_hats
         return y_hat
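What the new `return_base` flag changes in the forward pass, as a minimal shape sketch. The sizes (batch of 4, 3 base models, 48-hour horizon) are illustrative, and it assumes the weighting net emits one weight per model that broadcasts across the horizon; `WeightingModel` itself is stubbed with a softmax over random scores since only the combination step matters here.

```python
import torch

batch, n_models, horizon = 4, 3, 48

# Stand-ins for the stacked base-model outputs and the weighting net's output.
y_hats = torch.randn(batch, n_models, horizon)              # torch.stack(y_hat_list, axis=1)
w = torch.softmax(torch.randn(batch, n_models, 1), dim=1)   # self.weighting(x), one weight per model

y_hat = torch.sum(y_hats * w, axis=1)                       # weighted average, shape (batch, horizon)

# With return_base=True the ensemble also hands back the raw base predictions,
# which the 'save_both' mode below stores alongside the ensemble forecast.
return_base = True
out = (y_hat, y_hats) if return_base else y_hat
```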
diff --git a/src/utils/db_tools/db_tools.py b/src/utils/db_tools/db_tools.py
index a44a132ad4478acc44941b3eeb37844262684578..9dfddb7fd753cf65ba12b35109fd8bfbf6eac645 100644
--- a/src/utils/db_tools/db_tools.py
+++ b/src/utils/db_tools/db_tools.py
@@ -129,19 +129,11 @@ class OracleWaVoConnection:
         input_meta_data = self._ext_fcst_metadata(gauge_config)
 
         if len(self.times)> 1:
-            df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created)
+            df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created,series=True)
+            df_ext = self.load_input_ext_fcst(gauge_config["external_fcst"]) #load external forecasts
 
-            #load external forecasts
-            if self.main_config.get("fake_external"):
-                df_temp = None
-            else:
-                stmst = select(InputForecasts).where(InputForecasts.sensor_name.in_(gauge_config["external_fcst"]),
-                                                     between(InputForecasts.tstamp, self.times[0], self.times[-1]),
-                                                     InputForecasts.member.in_(self.members))
-
-                df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
-            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_temp)
+            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_ext,model.differencing)
             df_pred = pred_multiple_db(model, dataset)
             df_pred["created"] = created
             df_pred["sensor_name"] = gauge_config["gauge"]
@@ -151,7 +143,6 @@ class OracleWaVoConnection:
                 PegelForecasts.model==gauge_config["model_folder"].name,
                 PegelForecasts.sensor_name==gauge_config["gauge"],
                 PegelForecasts.tstamp.in_(df_pred["tstamp"]),
-                #between(PegelForecasts.tstamp, df_pred["tstamp"].min(), df_pred["tstamp"].max()),
             )
             df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
@@ -231,37 +222,12 @@ class OracleWaVoConnection:
             None
         """
         try:
-            if member == -1:
-                df_input = df_base_input.fillna(0)
-            else:
-                # replace fake forecast with external forecast
-                df_input = df_base_input.copy()
-                for input_meta in input_meta_data:
-                    if input_meta.ensemble_members == 1:
-                        temp_member = 0
-                    elif input_meta.ensemble_members == 21:
-                        temp_member = member
-                    else:
-                        temp_member = member #TODO what should happen here?
-
-                    fcst_data = df_temp[(df_temp["member"]== temp_member) & (df_temp["sensor_name"] == input_meta.sensor_name)]
-                    if len(fcst_data) == 0:
-                        raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
-
-                    # Replace missing values with forecasts.
-                    # It is guaranteed that the only missing values are in the "future" (after end_time)
-                    assert df_input[:-48].isna().sum().sum() == 0
-                    nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
-                    nan_indices = nan_indices[nan_indices >= 0]
-                    nan_indices2 = nan_indices + (144 - 48)
-                    col_idx = df_input.columns.get_loc(input_meta.sensor_name)
-                    df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
-
+            df_input = self.fill_base_input(member,df_base_input,input_meta_data,df_temp)
         except LookupError as e:
             logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]})
             y = torch.Tensor(48).fill_(np.nan)
         else:
-            y = pred_single_db(model, df_input)
+            y = pred_single_db(model, df_input,differencing=model.differencing)
         finally:
             self.insert_forecast(
                 y,gauge_config,end_time, member, created
@@ -274,7 +240,10 @@ class OracleWaVoConnection:
         Args:
             ensemble_config (dict): The configuration for the ensemble, model_folder, ensemble_name, ensemble_model.
         """
+        #TODO add logging here
+        #TODO zrxp exports?
         created = datetime.now(timezone.utc)
+        # Collect the base models
         stmt = select(Modell).where(Modell.modellname.in_(bindparam("model_names",expanding=True)))
         df_models = pd.read_sql(sql=stmt,con=self.engine,params={"model_names": ensemble_config["model_names"]})
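For orientation, the `ensemble_config` consumed in the following hunk might look roughly like this. The key names (`model_names`, `ensemble_name`, `ensemble_model`, `mode`) and the four mode values are taken from the code; the concrete values are hypothetical.

```python
# Hypothetical ensemble_config — keys match the code, values are made up.
ensemble_config = {
    "model_names": ["modelA", "modelB", "modelC"],   # base models, looked up in the Modell table
    "ensemble_name": "my_ensemble",                  # becomes model_folder of the fake gauge_config
    "ensemble_model": "/path/to/ensemble_dir",       # folder with checkpoints/ and the config yaml
    "mode": "save_both",  # one of: 'naiv', 'only_ensemble', 'save_both', 'calc_both'
}
```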
@@ -285,54 +254,113 @@ class OracleWaVoConnection:
             model_list.append(model_path)
             pseudo_model_names.append(model_path.name)
 
-        stmt = select(PegelForecasts).where(
-            PegelForecasts.tstamp.in_(bindparam("times",expanding=True)),
-            PegelForecasts.model.in_(bindparam("pseudo_model_names", expanding=True)),
-        )
-        #self.times
-        # SensorData.sensor_name.in_(bindparam("model_columns", expanding=True)),
-        df_ens = pd.read_sql(
-            sql=stmt,
-            con=self.engine,
-            index_col="tstamp",
-            params={
-                "times": self.times,
-                "pseudo_model_names": pseudo_model_names,
-            },
-        )
-        fake_config = {'gauge': df_ens.iloc[0,0],
-                       'model_folder' : ensemble_config['ensemble_name']}
+        fake_config = self._generate_fake_config(df_models,ensemble_config)
+        input_meta_data = self._ext_fcst_metadata(fake_config)
+        df_ext = self.load_input_ext_fcst(fake_config["external_fcst"]) #load external forecasts
+
+        #load ensemble model, maybe with dummies
+        if ensemble_config["mode"] != 'naiv':
+            model, ensemble_yaml = hp.load_ensemble_model(ensemble_config["ensemble_model"],model_list,dummy=ensemble_config["mode"] == 'only_ensemble')
+            model.set_return_base(ensemble_config["mode"] == 'save_both')
+            differencing = ensemble_yaml["differencing"]
+
-        if ensemble_config["ensemble_model"] is not None:
-            base_model = hp.load_model(model_list[0])
+        #save_both -> calculate the base models and the ensemble model and save both
+        #calc_both -> calculate the base models and the ensemble model but only save the ensemble model
+
+        #Warning: This behaves slightly differently from the other versions.
+        #Here the external forecasts only overwrite their corresponding sensor values if those are NaN.
+        #This means the model has more information for older timestamps.
+        if ensemble_config["mode"] == 'only_ensemble' or ensemble_config["mode"] == 'naiv':
+            stmt = select(PegelForecasts).where(
+                PegelForecasts.tstamp.in_(bindparam("times",expanding=True)),
+                PegelForecasts.model.in_(bindparam("pseudo_model_names", expanding=True)))
+            df_ens = pd.read_sql(sql=stmt,con=self.engine,index_col="tstamp",
+                                 params={"times": self.times,"pseudo_model_names": pseudo_model_names})
+
+            #self.insert_forecast(y,fake_config,cur_time, member, created)
+            #TODO this could be much more efficient if the forecasts were collected first
+            if ensemble_config["mode"] == 'naiv': # -> load the results of the base models from the database and take the naive mean
+                for cur_time in df_ens.index.unique():
+                    df_y_time_slice = df_ens.loc[cur_time]
+                    for member in self.members:
+                        df_y_member = df_y_time_slice[df_y_time_slice["member"] == member]
+                        y = df_y_member.iloc[:,4:].mean()
+                        y = torch.Tensor(y)
+                        self.insert_forecast(y,fake_config,cur_time, member, created)
+
+            if ensemble_config["mode"] == 'only_ensemble': # -> load the results of the base models from the database
+                for cur_time in df_ens.index.unique():
+                    df_y_time_slice = df_ens.loc[cur_time]
+                    df_base_input = self.load_input_db(model.max_in_size,cur_time,fake_config,created)
+                    for member in self.members:
+                        df_y_member = df_y_time_slice[df_y_time_slice["member"] == member]
+                        y_ens = df_y_member.iloc[:,4:].values
+                        try:
+                            x = self._get_x(df_base_input,member,input_meta_data,df_ext,differencing,model)
+                        except LookupError as e:
+                            logging.error(e.args[0],extra={"gauge":fake_config["gauge"]})
+                            y = torch.Tensor(48).fill_(np.nan)
+                        else:
+                            w = model.weighting(x)
+                            y = torch.sum((torch.tensor(y_ens)*w),axis=1)[0]
+                        finally:
+                            self.insert_forecast(y,fake_config,cur_time, member, created)
+
+        elif ensemble_config["mode"] == 'save_both' or ensemble_config["mode"] == 'calc_both':
+            #df_base_input_all = self.load_input_db(model.max_in_size,self.times[-1],fake_config,created)
+
+            df_base_input = self.load_input_db(model.max_in_size, self.times[-1], fake_config,created,series=True)
+            ds = self._make_dataset(fake_config,df_base_input, model,input_meta_data,df_ext,ensemble_yaml["differencing"])
+            #df_pred = pred_multiple_db(model, dataset)
+            data_loader = DataLoader(ds, batch_size=256,num_workers=1)
+            pred = hp.get_pred(model, data_loader,None)
+
+            df_base_preds = []
+            df_template = ds.get_whole_index()
+            if ensemble_config["mode"] == 'save_both':
+                pred, base_preds = pred
+                for j in range(base_preds.shape[1]):
+                    df_temp = pd.concat([df_template, pd.DataFrame(base_preds[:,j,:],columns=[f"h{i}" for i in range(1,49)])], axis=1)
+                    df_temp["model"] = model_list[j].name
+                    df_base_preds.append(df_temp)
+
+            df_temp = pd.concat([df_template, pd.DataFrame(pred,columns=[f"h{i}" for i in range(1,49)])], axis=1)
+            df_temp["model"] = fake_config["model_folder"].name
+            df_base_preds.append(df_temp)
+            df_pred = pd.concat(df_base_preds)
+            df_pred["created"] = created
+            df_pred["sensor_name"] = fake_config["gauge"]
 
-        ensemble_folder = Path(ensemble_config["ensemble_model"])
-        ensemble_yaml = hp.load_settings_model(ensemble_folder)
-        from models.ensemble_models import WaVoLightningEnsemble
 
+            #pseudo_model_names
+            stmt = select(PegelForecasts.tstamp,PegelForecasts.member,PegelForecasts.model).where(
+                #PegelForecasts.model==fake_config["model_folder"].name,
+                #PegelForecasts.model.in_(pseudo_model_names + [fake_config["model_folder"].name]),
+                PegelForecasts.model.in_(df_pred["model"].unique()),
+                PegelForecasts.sensor_name==fake_config["gauge"],
+                PegelForecasts.tstamp.in_(df_pred["tstamp"].unique()),
+            )
+
+            df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
+            merged = df_pred.merge(df_already_inserted[['tstamp', 'member','model']],on=['tstamp', 'member','model'], how='left',indicator=True)
+            df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+            df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1)
+            #sqlalchemy.orm.exc.StaleDataError: UPDATE statement on table 'pegel_forecasts' expected to update 13860 row(s); 12600 were matched.
+            with Session(self.engine) as session:
+                forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist()
+                session.add_all(forecast_objs)
+                session.execute(update(PegelForecasts),df_pred_update.to_dict(orient='records'))
+                session.commit()
 
-        checkpoint_path = next((ensemble_folder / "checkpoints").iterdir())
-        temp_model_list = len(ensemble_yaml["model_path_list"])*[base_model]
-        model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=temp_model_list)
-        #model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=None)
-        for cur_time in df_ens.index.unique():
-            df_time_slice = df_ens.loc[cur_time]
-            for member in self.members:
-                if ensemble_config["ensemble_model"] is None:
-                    df_member = df_time_slice[df_time_slice["member"] == member]
-                    y = df_member.iloc[:,4:].mean()
-                    y = torch.Tensor(y)
-                else:
-                    y = None
-
-                self.insert_forecast(y,fake_config,cur_time, member, created)
 
-    def load_input_db(self, in_size, end_time, gauge_config,created) -> pd.DataFrame:
+
+
+    def load_input_db(self, in_size, end_time, gauge_config,created,series=False) -> pd.DataFrame:
         """
         Loads input data from the database based on the current configuration.
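The create/update split used above (the reason for the `StaleDataError` comment) relies on pandas' merge indicator; rows already present in the table get an UPDATE, everything else an INSERT. A minimal, self-contained illustration of the pattern:

```python
import pandas as pd

df_pred = pd.DataFrame({"tstamp": [1, 2, 3], "member": [0, 0, 0],
                        "model": ["ens", "ens", "ens"], "h1": [1.0, 2.0, 3.0]})
# Rows already present in pegel_forecasts (as returned by the SELECT above).
df_already_inserted = pd.DataFrame({"tstamp": [2], "member": [0], "model": ["ens"]})

merged = df_pred.merge(df_already_inserted[["tstamp", "member", "model"]],
                       on=["tstamp", "member", "model"], how="left", indicator=True)
df_create = merged[merged["_merge"] == "left_only"].drop("_merge", axis=1)  # INSERT these
df_update = merged[merged["_merge"] == "both"].drop("_merge", axis=1)       # UPDATE these
```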
@@ -341,6 +369,7 @@ class OracleWaVoConnection:
             end_time (pd.Timestamp): The end time of the input data.
             gauge_config (dict): The configuration for the gauge, important here are the columns and external forecasts.
             created (pd.Timestamp): The timestamp of the start of the forecast creation.
+            series (bool): If True, the input data for all forecast times is loaded at once. This is not the standard behaviour, because the forecasts would then be visible to models that shouldn't have them yet.
 
         Returns:
             pd.DataFrame: The loaded input data as a pandas DataFrame.
@@ -349,7 +378,7 @@ class OracleWaVoConnection:
         columns=gauge_config["columns"]
         external_fcst=gauge_config["external_fcst"]
 
-        if len(self.times) > 1:
+        if len(self.times) > 1 and series:
             first_end_time = self.times[0]
             start_time = first_end_time - pd.Timedelta(in_size - 1, "hours")
             if self.main_config.get("fake_external"):
@@ -452,6 +481,62 @@ class OracleWaVoConnection:
 
         return df_input
 
+    def load_input_ext_fcst(self,external_fcst = None) -> pd.DataFrame:
+        """Loads the external input forecasts for the given sensors and the current times/members from the database."""
+        if self.main_config.get("fake_external"):
+            return None
+
+        if external_fcst is None:
+            external_fcst = []
+
+        stmt = select(InputForecasts).where(InputForecasts.sensor_name.in_(external_fcst),
+                                            between(InputForecasts.tstamp, self.times[0], self.times[-1]),
+                                            InputForecasts.member.in_(self.members))
+
+        df_ext = pd.read_sql(sql=stmt,con=self.engine,index_col="tstamp")
+        return df_ext
+
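The point of splitting `load_input_ext_fcst` from `fill_base_input` is that the database is queried once and the per-member filling then happens in memory. A usage sketch of the intended call order (not runnable stand-alone; `conn`, `model`, `gauge_config`, `df_base_input`, and `input_meta_data` are assumed to exist as in the forecasting code above):

```python
# Hypothetical driver loop around the two new helpers.
df_ext = conn.load_input_ext_fcst(gauge_config["external_fcst"])  # one query for all members
for member in conn.members:
    df_input = conn.fill_base_input(member, df_base_input, input_meta_data, df_ext)
    y = pred_single_db(model, df_input, differencing=model.differencing)
```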
+    def fill_base_input(self,member,df_base_input,input_meta_data,df_ext) -> pd.DataFrame:
+        """Fills the base input data with external forecasts, depending on the current (available) member.
+
+        Args:
+            member (int): member identifier (-1 to 21)
+            df_base_input (pd.DataFrame): sensor data from the database
+            input_meta_data (list): metadata for the external forecasts
+            df_ext (pd.DataFrame): external forecast data
+
+        Returns:
+            pd.DataFrame: Input data with external forecasts inserted
+        """
+        if member == -1:
+            df_input = df_base_input.fillna(0)
+        else:
+            # replace fake forecast with external forecast
+            df_input = df_base_input.copy()
+            for input_meta in input_meta_data:
+                if input_meta.ensemble_members == 1:
+                    temp_member = 0
+                elif input_meta.ensemble_members == 21:
+                    temp_member = member
+                else:
+                    temp_member = member #TODO what should happen here?
+
+                fcst_data = df_ext[(df_ext["member"]== temp_member) & (df_ext["sensor_name"] == input_meta.sensor_name)]
+                if len(fcst_data) == 0:
+                    raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
+
+                # Replace missing values with forecasts.
+                # It is guaranteed that the only missing values are in the "future" (after end_time)
+                assert df_input[:-48].isna().sum().sum() == 0
+                nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
+                nan_indices = nan_indices[nan_indices >= 0]
+                nan_indices2 = nan_indices + (144 - 48)
+                col_idx = df_input.columns.get_loc(input_meta.sensor_name)
+                df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
+
+        return df_input
+
+
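The index arithmetic above is easiest to see with concrete numbers. A worked, runnable example, assuming the hard-coded window of 144 input rows whose last 48 are the forecast horizon (the `precip_fcst` column name and the values are hypothetical; `fcst` stands in for `fcst_data.values[0, 2:]`):

```python
import numpy as np
import pandas as pd

in_size = 144   # assumed input window; the offsets 144 and 48 are hard-coded above
horizon = 48    # forecast horizon, the only region allowed to contain NaNs

df_input = pd.DataFrame({"precip_fcst": np.arange(in_size, dtype="float32")})
df_input.iloc[-10:, 0] = np.nan                          # pretend the last 10 hours are missing

fcst = np.arange(1000, 1000 + horizon, dtype="float32")  # external forecast, one value per horizon step

# Absolute NaN positions (134..143) shifted to positions relative to the forecast window (38..47),
# then shifted back to write the matching forecast values into the frame.
nan_idx = np.where(df_input["precip_fcst"].isna())[0] - (in_size - horizon)
nan_idx = nan_idx[nan_idx >= 0]
df_input.iloc[nan_idx + (in_size - horizon), 0] = fcst[nan_idx]

assert df_input["precip_fcst"].isna().sum() == 0
```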
 
     def add_zrxp_data(self, zrxp_folders: List[Path]) -> None:
         """Adds zrxp data to the database
@@ -660,6 +745,32 @@ class OracleWaVoConnection:
             #try:
             session.commit()
 
+    def _generate_fake_config(self,df_models,ensemble_config):
+        """Builds a gauge_config-style dict for the ensemble, based on the sensors of the first base model."""
+        model_id = df_models.iloc[0,0].item()
+
+        query = (
+            select(ModellSensor)
+            .where(ModellSensor.modell_id == model_id)
+            .order_by(ModellSensor.ix)
+        )
+        df_ms = pd.read_sql(query,self.engine)
+
+        fake_config = {'gauge': df_models["target_sensor_name"].values[0],
+                       'model_folder' : ensemble_config['ensemble_name'],
+                       'external_fcst' : list(df_ms.dropna()["sensor_name"]), #drop the rows without a value in vhs_gebiet
+                       'columns' : list(df_ms["sensor_name"])}
+        return fake_config
+
+    def _get_x(self,df_base_input,member,input_meta_data,df_ext,differencing,model):
+        """Builds the scaled input tensor for a single member."""
+        df_input = self.fill_base_input(member,df_base_input,input_meta_data,df_ext)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(action="ignore", category=UserWarning)
+            if differencing == 1:
+                df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
+            x = model.scaler.transform(df_input.values)
+            x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
+        return x
+
     def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name):
 
         if self.main_config.get("export_zrxp") and member == 0:
@@ -704,10 +815,10 @@ class OracleWaVoConnection:
 
-    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext):
+    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext,differencing):
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=UserWarning)
-            if model.differencing == 1:
+            if differencing == 1:
                 df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
             #TODO time embedding stuff
             df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns)
@@ -876,20 +987,21 @@ class OracleWaVoConnection:
 
 
 @torch.no_grad()
-def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
+def pred_single_db(model, df_input: pd.DataFrame,differencing=True) -> torch.Tensor:
     """
     Predicts the output for a single input data point.
 
     Args:
         model (LightningModule): The model to use for prediction.
         df_input (pandas.DataFrame): Input data point as a DataFrame.
+        differencing (bool): If True, a column with the differenced target value is appended as the last column.
 
     Returns:
        torch.Tensor: Predicted output as a tensor.
     """
     with warnings.catch_warnings():
         warnings.filterwarnings(action="ignore", category=UserWarning)
-        if model.differencing == 1:
+        if differencing == 1:
             df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
         #TODO time embedding stuff
 
         x = model.scaler.transform(
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index ce591b6212f15526f1ab9b144d23afc0bbd349d9..23c293da67985b78a4078cbf287cf7e92d1bd3d1 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -16,6 +16,7 @@ from torch.utils.data import DataLoader
 # from data_tools.data_module import WaVoDataModule
 from data_tools.data_module import WaVoDataModule
 from models.lightning_module import WaVoLightningModule
+from models.ensemble_models import WaVoLightningEnsemble
 
 
 def load_settings_model(model_dir: Path) -> dict:
@@ -90,11 +91,17 @@ def get_pred(
     if move_model:
         model.cuda()
 
-    y_pred = np.concatenate(pred)
+    if getattr(model, 'return_base', None):
+        pred, base_preds = pred[0] #TODO this function is not tested with more than one batch
+        return pred, base_preds
 
-    if y_true is not None:
-        y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred?
-        y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
+    else:
+
+        y_pred = np.concatenate(pred)
+
+        if y_true is not None:
+            y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred?
+            y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
 
     return y_pred
@@ -211,3 +218,36 @@ def load_modules_rl(
     data_module.setup(stage="")
 
     return model_list, data_module
+
+
+def load_ensemble_model(model_dir, model_dir_list,dummy=False):
+    """Loads a WaVoLightningEnsemble from the given folder; a list of the base models that are part of the ensemble is also required.
+
+    Args:
+        model_dir : Path to the folder containing the ensemble model
+        model_dir_list : List of paths to the models that are part of the ensemble
+        dummy (bool, optional): If True, only the first base model in model_dir_list is actually loaded. Defaults to False.
+
+    Returns:
+        tuple: the ensemble model and the ensemble_yaml with info about that model
+    """
+    model_dir = Path(model_dir)
+    ensemble_yaml = load_settings_model(model_dir)
+
+    if dummy:
+        base_model = load_model(model_dir=model_dir_list[0])
+        model_list = len(model_dir_list)*[base_model]
+    else:
+        model_list = [load_model(model_dir=Path(p)) for p in model_dir_list]
+
+    for model in model_list:
+        model.eval()
+
+    checkpoint_path = next((model_dir / "checkpoints").iterdir())
+    model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=model_list)
+
+    model.gauge_idx = model_list[0].gauge_idx #hacky
+    model.in_size = model.max_in_size
+    model.eval()
+    return model, ensemble_yaml
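Finally, a usage sketch for the new loader. The paths are hypothetical, the import alias mirrors the `hp.*` calls in db_tools (the exact import form is an assumption), and it presumes the ensemble folder contains the usual `checkpoints/` subfolder and config yaml:

```python
from pathlib import Path

from utils import helpers as hp  # assumed alias, matching hp.load_ensemble_model above

# Hypothetical model folders.
base_dirs = [Path("models/modelA"), Path("models/modelB"), Path("models/modelC")]
model, ensemble_yaml = hp.load_ensemble_model("models/my_ensemble", base_dirs, dummy=False)

model.set_return_base(True)            # forward() now returns (y_hat, y_hats)
print(ensemble_yaml["differencing"])   # the caller needs this to build the dataset
```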