diff --git a/src/utils/db_tools/db_tools.py b/src/utils/db_tools/db_tools.py
index b64a16cba96248c3d0671aa1bb493869a25afd30..2e8b0783aa96d613fdd51b3445288297ba5830ba 100644
--- a/src/utils/db_tools/db_tools.py
+++ b/src/utils/db_tools/db_tools.py
@@ -22,6 +22,7 @@
 from sqlalchemy.exc import IntegrityError
 import torch
 import utils.helpers as hp
+from utils.utility import scale_standard
 from utils.db_tools.orm_classes import (
     ModellSensor,
     SensorData,
@@ -40,6 +41,59 @@
 from torch.utils.data import DataLoader,Dataset
 
 
+class DBDataSet(Dataset):
+    """Dataset of scaled input windows taken from the database; the last 48 h of the
+    external-forecast columns are overwritten with the requested ensemble member."""
+
+    def __init__(self, df_scaled,in_size,ext_meta,df_ext=None) -> None:
+        self.df_scaled = df_scaled
+        self.in_size = in_size
+        self.df_ext = df_ext
+        self.ext_meta = ext_meta #TODO rename
+        self.y = torch.Tensor(([0])) # dummy
+
+        if df_ext is None:
+            self.max_ens_size = 1
+            self.common_indexes = df_scaled.index[self.in_size:] #TODO check if this is correct
+        else:
+            self.max_ens_size = max([i_meta.ensemble_members for i_meta in ext_meta])
+            self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta]
+            self.common_indexes = reduce(np.intersect1d, self.unique_indexes)
+
+    def __len__(self) -> int:
+        return len(self.common_indexes)*self.max_ens_size
+
+    def get_whole_index(self):
+        if self.df_ext is None:
+            rows = [(self.common_indexes[idx // self.max_ens_size],-1) for idx in range(len(self))]
+        else:
+            rows = [(self.common_indexes[idx // self.max_ens_size],idx % self.max_ens_size) for idx in range(len(self))]
+        return pd.DataFrame(rows,columns=["tstamp","member"])
+
+    def __getitem__(self, idx):
+        if idx >= len(self):
+            raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}")
+
+        tstamp = self.common_indexes[idx // self.max_ens_size]
+        ens_num = idx % self.max_ens_size
+
+        x = self.df_scaled[:tstamp][-self.in_size:].astype("float32")
+
+        if self.df_ext is not None:
+            for meta in self.ext_meta:
+                if meta.ensemble_members == 1:
+                    #ext_values = self.df_ext_forecasts.loc[tstamp,meta_data.sensor_name]
+                    ext_values = self.df_ext[self.df_ext["sensor_name"]== meta.sensor_name].loc[tstamp].iloc[2:].values
+                else:
+                    #TODO this crashes when a timestamp only has either ensemble runs or only main runs
+                    ext_values = self.df_ext[(self.df_ext["sensor_name"]== meta.sensor_name) & (self.df_ext["member"]== ens_num)].loc[tstamp].iloc[2:].values
+
+                col_idx = x.columns.get_loc(meta.sensor_name)
+                x.iloc[-48:,col_idx] = ext_values.astype("float32")
+
+        #y = self.y[idx+self.in_size:idx+self.in_size+self.out_size]
+        return x.values, self.y
+
+
 class OracleWaVoConnection:
     """Class for writing/reading from the Wavo Database."""
@@ -87,12 +141,29 @@
 
 
                 df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
 
-                dataset = self._make_dataset(df_base_input, model,input_meta_data,df_temp)
-                #else:
-
+                dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_temp)
+                df_pred = pred_multiple_db(model, dataset)
+                df_pred["created"] = created
+                df_pred["sensor_name"] = gauge_config["gauge"]
+                df_pred["model"] = gauge_config["model_folder"].name
+
+                stmt = select(PegelForecasts.tstamp,PegelForecasts.member).where(
+                    PegelForecasts.model==gauge_config["model_folder"].name,
+                    PegelForecasts.sensor_name==gauge_config["gauge"],
+                    PegelForecasts.tstamp.in_(df_pred["tstamp"]),
+                    #between(PegelForecasts.tstamp, df_pred["tstamp"].min(), df_pred["tstamp"].max()),
+                )
+
+                df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
+                merged = df_pred.merge(df_already_inserted[['tstamp', 'member']],on=['tstamp', 'member'], how='left',indicator=True)
+                df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+                df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1)
-                pred_multiple_db(model, df_base_input)
-                print("a")
+                with Session(self.engine) as session:
+                    forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist()
+                    session.add_all(forecast_objs)
+                    session.execute(update(PegelForecasts),df_pred_update.to_dict(orient='records'))
+                    session.commit()
 
             else:
                 #TODO this is weird, cause it's just one at the moment
@@ -528,34 +599,37 @@
 
     def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name, created):
         if self.main_config.get("export_zrxp") and member == 0:
-            with Session(self.engine) as session:
-                stmt = select(Modell).where(Modell.modelldatei == str(gauge_config["model_folder"].resolve()))
-                model_obj = session.scalar(stmt)
-            if model_obj is None:
+            try:
+                df_models = pd.read_sql(sql=select(Modell.modellname, Modell.modelldatei), con=self.engine)
+                df_models["modelldatei"] = df_models["modelldatei"].apply(Path)
+                df_models["modelldatei"] = df_models["modelldatei"].apply(lambda x: x.resolve())
+                real_model_name = df_models[df_models["modelldatei"]== gauge_config["model_folder"].resolve()]["modellname"].values[0]
+            except IndexError:
                 logging.error("Model %s not found in database, skipping zrxp export",gauge_config["model_folder"])
                 return
+            else:
 
-            target_file =self.main_config["zrxp_out_folder"] / f"{end_time.strftime('%Y%m%d%H')}_{sensor_name.replace(',','_')}_{model_name}.zrx"
-            #2023 11 19 03
-            df_zrxp = pd.DataFrame(forecast,columns=["value"])
-            df_zrxp["timestamp"] = end_time
-            df_zrxp["forecast"] = pd.date_range(start=end_time,periods=49,freq="1h")[1:]
-            df_zrxp["member"] = member
-            df_zrxp = df_zrxp[['timestamp','forecast','member','value']]
+                target_file = self.main_config["zrxp_out_folder"] / f"{end_time.strftime('%Y%m%d%H')}_{sensor_name.replace(',','_')}_{model_name}.zrx"
+                #2023 11 19 03
+                df_zrxp = pd.DataFrame(forecast,columns=["value"])
+                df_zrxp["timestamp"] = end_time
+                df_zrxp["forecast"] = pd.date_range(start=end_time,periods=49,freq="1h")[1:]
+                df_zrxp["member"] = member
+                df_zrxp = df_zrxp[['timestamp','forecast','member','value']]
 
-            with open(target_file , 'w', encoding="utf-8") as file:
-                file.write('#REXCHANGEWISKI.' + model_obj.modellname + '.W.KNN|*|\n')
-                file.write('#RINVAL-777|*|\n')
-                file.write('#LAYOUT(timestamp,forecast, member,value)|*|\n')
+                with open(target_file , 'w', encoding="utf-8") as file:
+                    file.write('#REXCHANGEWISKI.' + real_model_name + '.W.KNN|*|\n')
+                    file.write('#RINVAL-777|*|\n')
+                    file.write('#LAYOUT(timestamp,forecast, member,value)|*|\n')
 
-            df_zrxp.to_csv(path_or_buf = target_file,
-                        header = False,
-                        index=False,
-                        mode='a',
-                        sep = ' ',
-                        date_format = '%Y%m%d%H%M')
+                df_zrxp.to_csv(path_or_buf = target_file,
+                            header = False,
+                            index=False,
+                            mode='a',
+                            sep = ' ',
+                            date_format = '%Y%m%d%H%M')
 
 
     def _ext_fcst_metadata(self, gauge_config):
@@ -568,61 +642,32 @@
 
 
-    def _make_dataset(self, df_input, model,input_meta_data,df_ext_forecasts):
+    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext):
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=UserWarning)
             if model.differencing == 1:
                 df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
            #TODO time embedding stuff
            df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns)
-            #x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0) #TODO skalierung df_ext_forecast
+            trans_cols = [f"h{i}" for i in range(1,49)]
+            for col in gauge_config["external_fcst"]:
+                idx = gauge_config["columns"].index(col)
+                def scale_curry(x):
+                    return scale_standard(x,model.scaler.mean_[idx],model.scaler.scale_[idx]) #TODO maybe people want to use different scalers
-            class DBDataSet(Dataset):
-                def __init__(self, df_scaled,in_size,ext_meta,df_ext) -> None:
-                    self.df_scaled = df_scaled
-                    self.in_size = in_size
-                    self.df_ext = df_ext
-                    self.ext_meta = ext_meta #TODO rename
+                df_ext.loc[df_ext["sensor_name"]==col,trans_cols] = df_ext.loc[df_ext["sensor_name"]==col,trans_cols].apply(scale_curry)
-                    self.max_ens_size = max([i_meta.ensemble_members for i_meta in ext_meta])
-
-                    self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta]
-                    self.common_indexes = reduce(np.intersect1d, self.unique_indexes)
-
-                def __len__(self) -> int:
-                    return len(self.common_indexes)*self.max_ens_size
-
-                def get_whole_index(self):
-                    rows = [(self.common_indexes[idx // self.max_ens_size],idx % self.max_ens_size) for idx in range(len(self))]
-
-                    return pd.DataFrame(rows,columns=["tstamp","member"])
-
-
-                def __getitem__(self, idx):
-                    tstamp = self.common_indexes[idx // self.max_ens_size]
-                    ens_num = idx % self.max_ens_size
+            print("a")
+            #scale_standard
-                    x = self.df_scaled[:tstamp][-self.in_size:]
+            #with Session(self.engine) as session:
+            #    select()
-                    for meta in self.ext_meta:
-                        print(meta)
-                        if meta.ensemble_members == 1:
-                            #ext_values = self.df_ext_forecasts.loc[tstamp,meta_data.sensor_name]
-                            ext_values = self.df_ext[self.df_ext["sensor_name"]== meta.sensor_name].loc[tstamp].iloc[2:].values
-                        else:
-                            ext_values = self.df_ext[(self.df_ext["sensor_name"]== meta.sensor_name) & (self.df_ext["member"]== ens_num)].loc[tstamp].iloc[2:].values
+            ds = DBDataSet(df_scaled,model.in_size,input_meta_data,df_ext)
+            return ds
-
-                        col_idx = x.columns.get_loc(meta.sensor_name)
-                        x.iloc[-48:,col_idx] = ext_values.astype("float32")
-
-                        if idx >= len(self):
-                            raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}")
-                        #y = self.y[idx+self.in_size:idx+self.in_size+self.out_size]
-                        return x.values, None #TODO auch y ?
-
 
     def maybe_update_tables(self) -> None:
         """
@@ -806,7 +851,7 @@ def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
 
 
 @torch.no_grad()
-def pred_multiple_db(model, df_input: pd.DataFrame) -> torch.Tensor:
+def pred_multiple_db(model, ds: Dataset) -> torch.Tensor:
     """
     Predicts the output for a single input data point.
 
@@ -817,23 +862,13 @@ def pred_multiple_db(model, df_input: pd.DataFrame) -> torch.Tensor:
     Returns:
        torch.Tensor: Predicted output as a tensor.
     """
-    with warnings.catch_warnings():
-        warnings.filterwarnings(action="ignore", category=UserWarning)
-        if model.differencing == 1:
-            df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
-        #TODO time embedding stuff
-        x = model.scaler.transform(
-            df_input.values
-        )  # TODO values/df column names wrong/old
-    x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
-    # y = ut.inv_standard(model(x), model.scaler.mean_[model.target_idx], model.scaler.scale_[model.target_idx])
+    data_loader = DataLoader(ds, batch_size=256,num_workers=1)
+    pred = hp.get_pred(model, data_loader,None)
+    df_pred = ds.get_whole_index()
+    df_pred = pd.concat([df_pred, pd.DataFrame(pred,columns=[f"h{i}" for i in range(1,49)])], axis=1)
 
-    data_loader = DataLoader(dataset, batch_size=256,num_workers=8)
-    hp.get_pred(model, x,None)
-    y = model.predict_step((x, None), 0)
-
-    return y[0]
+    return df_pred
 
 
 def get_merge_forecast_query(out_size: int = 48) -> str:
diff --git a/src/utils/utility.py b/src/utils/utility.py
index 71e2ea1123674cbde055bdeb71b42f276c94dd21..d1039da2f7795545152aec752f43029f35cc8912 100644
--- a/src/utils/utility.py
+++ b/src/utils/utility.py
@@ -91,6 +91,24 @@ def inv_minmax(x, xmin: float, xmax: float):
     return x * (xmax-xmin) + xmin
 
 
+def scale_standard(x, mean: float, std: float):
+    """Standard scaling, also unpacks torch tensors if needed
+
+    Args:
+        x (np.array,torch.Tensor or similar): array to scale
+        mean (float): mean value of a (training) dataset
+        std (float): standard deviation value of a (training) dataset
+
+    Returns:
+        (np.array,torch.Tensor or similar): standard scaled values
+    """
+    if isinstance(x,torch.Tensor) and isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor):
+        return (x - mean) / std
+    if isinstance(mean, torch.Tensor):
+        mean = mean.item()
+    if isinstance(std, torch.Tensor):
+        std = std.item()
+    return (x - mean) / std
 def inv_standard(x, mean: float, std: float):
     """Inverse standard scaling, also unpacks torch tensors if needed