Commit e50657c6 authored by Michel Spils

long predictions

parent 2e853aad
@@ -22,6 +22,7 @@ from sqlalchemy.exc import IntegrityError
 import torch
 import utils.helpers as hp
+from utils.utility import scale_standard
 from utils.db_tools.orm_classes import (
     ModellSensor,
     SensorData,
@@ -40,6 +41,59 @@ from torch.utils.data import DataLoader,Dataset
+class DBDataSet(Dataset):
+    def __init__(self, df_scaled,in_size,ext_meta,df_ext=None) -> None:
+        self.df_scaled = df_scaled
+        self.in_size = in_size
+        self.df_ext = df_ext
+        self.ext_meta = ext_meta #TODO rename
+        self.y = torch.Tensor(([0])) # dummy
+        if df_ext is None:
+            self.max_ens_size = 1
+            self.common_indexes = df_scaled.index[self.in_size:] #TODO check if this is correct
+        else:
+            self.max_ens_size = max([i_meta.ensemble_members for i_meta in ext_meta])
+            self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta]
+            self.common_indexes = reduce(np.intersect1d, self.unique_indexes)
+
+    def __len__(self) -> int:
+        return len(self.common_indexes)*self.max_ens_size
+
+    def get_whole_index(self):
+        if self.df_ext is None:
+            rows = [(self.common_indexes[idx // self.max_ens_size],-1) for idx in range(len(self))]
+        else:
+            rows = [(self.common_indexes[idx // self.max_ens_size],idx % self.max_ens_size) for idx in range(len(self))]
+        return pd.DataFrame(rows,columns=["tstamp","member"])
+
+    def __getitem__(self, idx):
+        if idx >= len(self):
+            raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}")
+        tstamp = self.common_indexes[idx // self.max_ens_size]
+        ens_num = idx % self.max_ens_size
+        x = self.df_scaled[:tstamp][-self.in_size:].astype("float32")
+        if self.df_ext is not None:
+            for meta in self.ext_meta:
+                if meta.ensemble_members == 1:
+                    #ext_values = self.df_ext_forecasts.loc[tstamp,meta_data.sensor_name]
+                    ext_values = self.df_ext[self.df_ext["sensor_name"]== meta.sensor_name].loc[tstamp].iloc[2:].values
+                else:
+                    #TODO this crashes when a timestamp has either only ensemble runs or only main runs
+                    ext_values = self.df_ext[(self.df_ext["sensor_name"]== meta.sensor_name) & (self.df_ext["member"]== ens_num)].loc[tstamp].iloc[2:].values
+                col_idx = x.columns.get_loc(meta.sensor_name)
+                x.iloc[-48:,col_idx] = ext_values.astype("float32")
+        #y = self.y[idx+self.in_size:idx+self.in_size+self.out_size]
+        return x.values, self.y
 class OracleWaVoConnection:
     """Class for writing/reading from the Wavo Database."""
@@ -87,12 +141,29 @@ class OracleWaVoConnection:
             df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
-            dataset = self._make_dataset(df_base_input, model,input_meta_data,df_temp)
-            #else:
-            pred_multiple_db(model, df_base_input)
-            print("a")
+            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_temp)
+            df_pred = pred_multiple_db(model, dataset)
+            df_pred["created"] = created
+            df_pred["sensor_name"] = gauge_config["gauge"]
+            df_pred["model"] = gauge_config["model_folder"].name
+
+            stmt = select(PegelForecasts.tstamp,PegelForecasts.member).where(
+                PegelForecasts.model==gauge_config["model_folder"].name,
+                PegelForecasts.sensor_name==gauge_config["gauge"],
+                PegelForecasts.tstamp.in_(df_pred["tstamp"]),
+                #between(PegelForecasts.tstamp, df_pred["tstamp"].min(), df_pred["tstamp"].max()),
+            )
+            df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
+            merged = df_pred.merge(df_already_inserted[['tstamp', 'member']],on=['tstamp', 'member'], how='left',indicator=True)
+            df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+            df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1)
+
+            with Session(self.engine) as session:
+                forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist()
+                session.add_all(forecast_objs)
+                session.execute(update(PegelForecasts),df_pred_update.to_dict(orient='records'))
+                session.commit()
         else:
             #TODO this is weird, cause it's just one at the moment
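The insert/update split above uses pandas' indicator merge: forecasts whose (tstamp, member) key is not yet in PegelForecasts become new ORM objects, already-present keys become bulk updates. The pattern in isolation (toy frames, hypothetical values):

import pandas as pd

df_pred = pd.DataFrame({"tstamp": [1, 1, 2], "member": [0, 1, 0], "h1": [0.3, 0.4, 0.5]})
df_already_inserted = pd.DataFrame({"tstamp": [1], "member": [0]})  # keys already in the DB

merged = df_pred.merge(df_already_inserted[["tstamp", "member"]],
                       on=["tstamp", "member"], how="left", indicator=True)
df_create = merged[merged["_merge"] == "left_only"].drop("_merge", axis=1)  # rows to INSERT
df_update = merged[merged["_merge"] == "both"].drop("_merge", axis=1)       # rows to UPDATE
print(len(df_create), len(df_update))  # 2 1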
@@ -528,13 +599,16 @@ class OracleWaVoConnection:
     def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name, created):
         if self.main_config.get("export_zrxp") and member == 0:
-            with Session(self.engine) as session:
-                stmt = select(Modell).where(Modell.modelldatei == str(gauge_config["model_folder"].resolve()))
-                model_obj = session.scalar(stmt)
-            if model_obj is None:
+            try:
+                df_models = pd.read_sql(sql=select(Modell.modellname, Modell.modelldatei), con=self.engine)
+                df_models["modelldatei"] = df_models["modelldatei"].apply(Path)
+                df_models["modelldatei"] = df_models["modelldatei"].apply(lambda x: x.resolve())
+                real_model_name = df_models[df_models["modelldatei"]== gauge_config["model_folder"].resolve()]["modellname"].values[0]
+            except IndexError:
                 logging.error("Model %s not found in database, skipping zrxp export",gauge_config["model_folder"])
                 return
+            finally:
             target_file = self.main_config["zrxp_out_folder"] / f"{end_time.strftime('%Y%m%d%H')}_{sensor_name.replace(',','_')}_{model_name}.zrx"
             #2023 11 19 03
@@ -546,7 +620,7 @@ class OracleWaVoConnection:
             with open(target_file , 'w', encoding="utf-8") as file:
-                file.write('#REXCHANGEWISKI.' + model_obj.modellname + '.W.KNN|*|\n')
+                file.write('#REXCHANGEWISKI.' + real_model_name + '.W.KNN|*|\n')
                 file.write('#RINVAL-777|*|\n')
                 file.write('#LAYOUT(timestamp,forecast, member,value)|*|\n')
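With the lookup above resolving real_model_name, a file exported for a hypothetical model named PegelModel1 would start:

#REXCHANGEWISKI.PegelModel1.W.KNN|*|
#RINVAL-777|*|
#LAYOUT(timestamp,forecast, member,value)|*|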
@@ -568,60 +642,31 @@ class OracleWaVoConnection:
-    def _make_dataset(self, df_input, model,input_meta_data,df_ext_forecasts):
+    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext):
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=UserWarning)
             if model.differencing == 1:
                 df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
             #TODO time embedding stuff
             df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns)
-            #x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
         #TODO scaling of df_ext_forecast
-        class DBDataSet(Dataset):
-            def __init__(self, df_scaled,in_size,ext_meta,df_ext) -> None:
-                self.df_scaled = df_scaled
-                self.in_size = in_size
-                self.df_ext = df_ext
-                self.ext_meta = ext_meta #TODO rename
-                self.max_ens_size = max([i_meta.ensemble_members for i_meta in ext_meta])
-                self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta]
-                self.common_indexes = reduce(np.intersect1d, self.unique_indexes)
-            def __len__(self) -> int:
-                return len(self.common_indexes)*self.max_ens_size
-            def get_whole_index(self):
-                rows = [(self.common_indexes[idx // self.max_ens_size],idx % self.max_ens_size) for idx in range(len(self))]
-                return pd.DataFrame(rows,columns=["tstamp","member"])
-            def __getitem__(self, idx):
-                tstamp = self.common_indexes[idx // self.max_ens_size]
-                ens_num = idx % self.max_ens_size
-                x = self.df_scaled[:tstamp][-self.in_size:]
-                for meta in self.ext_meta:
-                    print(meta)
-                    if meta.ensemble_members == 1:
-                        #ext_values = self.df_ext_forecasts.loc[tstamp,meta_data.sensor_name]
-                        ext_values = self.df_ext[self.df_ext["sensor_name"]== meta.sensor_name].loc[tstamp].iloc[2:].values
-                    else:
-                        ext_values = self.df_ext[(self.df_ext["sensor_name"]== meta.sensor_name) & (self.df_ext["member"]== ens_num)].loc[tstamp].iloc[2:].values
-                    col_idx = x.columns.get_loc(meta.sensor_name)
-                    x.iloc[-48:,col_idx] = ext_values.astype("float32")
-                if idx >= len(self):
-                    raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}")
-                #y = self.y[idx+self.in_size:idx+self.in_size+self.out_size]
-                return x.values, None #TODO also y?
+        trans_cols = [f"h{i}" for i in range(1,49)]
+        for col in gauge_config["external_fcst"]:
+            idx = gauge_config["columns"].index(col)
+            def scale_curry(x):
+                return scale_standard(x,model.scaler.mean_[idx],model.scaler.scale_[idx]) #TODO maybe people want to use different scalers
+            df_ext.loc[df_ext["sensor_name"]==col,trans_cols] = df_ext.loc[df_ext["sensor_name"]==col,trans_cols].apply(scale_curry)
+        print("a")
+        #scale_standard
+        #with Session(self.engine) as session:
+        #    select()
+        ds = DBDataSet(df_scaled,model.in_size,input_meta_data,df_ext)
+        return ds
     def maybe_update_tables(self) -> None:
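_make_dataset now standard-scales each external forecast's horizon columns h1..h48 with the mean and scale the model's scaler learned for the matching input column. The transformation in isolation (toy frame, hypothetical statistics):

import pandas as pd

def scale_standard(x, mean, std):  # stands in for utils.utility.scale_standard
    return (x - mean) / std

trans_cols = [f"h{i}" for i in range(1, 4)]  # h1..h48 in the real code
df_ext = pd.DataFrame({"sensor_name": ["s1", "s2"], "h1": [2.0, 3.0],
                       "h2": [4.0, 5.0], "h3": [6.0, 7.0]})
mean_, scale_ = 2.0, 2.0  # would come from model.scaler for the matching column
mask = df_ext["sensor_name"] == "s1"
df_ext.loc[mask, trans_cols] = df_ext.loc[mask, trans_cols].apply(
    lambda col: scale_standard(col, mean_, scale_))
print(df_ext)  # only the s1 row is rescaled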
@@ -806,7 +851,7 @@ def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
 @torch.no_grad()
-def pred_multiple_db(model, df_input: pd.DataFrame) -> torch.Tensor:
+def pred_multiple_db(model, ds: Dataset) -> torch.Tensor:
     """
     Predicts the output for a single input data point.
@@ -817,23 +862,13 @@ def pred_multiple_db(model, df_input: pd.DataFrame) -> torch.Tensor:
     Returns:
         torch.Tensor: Predicted output as a tensor.
     """
-    with warnings.catch_warnings():
-        warnings.filterwarnings(action="ignore", category=UserWarning)
-        if model.differencing == 1:
-            df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
-        #TODO time embedding stuff
-        x = model.scaler.transform(
-            df_input.values
-        ) # TODO values/df column names wrong/old
-    x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
-    # y = ut.inv_standard(model(x), model.scaler.mean_[model.target_idx], model.scaler.scale_[model.target_idx])
-    data_loader = DataLoader(dataset, batch_size=256,num_workers=8)
-    hp.get_pred(model, x,None)
-    y = model.predict_step((x, None), 0)
-    return y[0]
+    data_loader = DataLoader(ds, batch_size=256,num_workers=1)
+    pred = hp.get_pred(model, data_loader,None)
+    df_pred = ds.get_whole_index()
+    df_pred = pd.concat([df_pred, pd.DataFrame(pred,columns=[f"h{i}" for i in range(1,49)])], axis=1)
+    return df_pred
 def get_merge_forecast_query(out_size: int = 48) -> str:
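pred_multiple_db now returns one row per (tstamp, member) sample with one column per forecast hour; the alignment between the dataset's index frame and the stacked predictions looks like this in isolation (toy data, hypothetical shapes):

import numpy as np
import pandas as pd

df_idx = pd.DataFrame({"tstamp": [1, 1], "member": [0, 1]})  # as from ds.get_whole_index()
pred = np.random.rand(2, 48)  # stacked model outputs, one row per dataset item
df_pred = pd.concat([df_idx, pd.DataFrame(pred, columns=[f"h{i}" for i in range(1, 49)])], axis=1)
print(df_pred.shape)  # (2, 50): tstamp, member, h1..h48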
@@ -91,6 +91,24 @@ def inv_minmax(x, xmin: float, xmax: float):
     return x * (xmax-xmin) + xmin
+def scale_standard(x, mean: float, std: float):
+    """Standard scaling, also unpacks torch tensors if needed
+
+    Args:
+        x (np.array, torch.Tensor or similar): array to scale
+        mean (float): mean value of a (training) dataset
+        std (float): standard deviation of a (training) dataset
+
+    Returns:
+        (np.array, torch.Tensor or similar): standard scaled values
+    """
+    if isinstance(x,torch.Tensor) and isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor):
+        return (x - mean) / std
+    if isinstance(mean, torch.Tensor):
+        mean = mean.item()
+    if isinstance(std, torch.Tensor):
+        std = std.item()
+    return (x - mean) / std
 def inv_standard(x, mean: float, std: float):
     """Inverse standard scaling, also unpacks torch tensors if needed