diff --git a/src/utils/db_tools/db_tools.py b/src/utils/db_tools/db_tools.py
index b64a16cba96248c3d0671aa1bb493869a25afd30..2e8b0783aa96d613fdd51b3445288297ba5830ba 100644
--- a/src/utils/db_tools/db_tools.py
+++ b/src/utils/db_tools/db_tools.py
@@ -22,6 +22,7 @@ from sqlalchemy.exc import IntegrityError
 import torch
 
 import utils.helpers as hp
+from utils.utility import scale_standard
 from utils.db_tools.orm_classes import (
     ModellSensor,
     SensorData,
@@ -40,6 +41,59 @@ from torch.utils.data import DataLoader,Dataset
 
 
     
+class DBDataSet(Dataset):
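+    """Inference dataset yielding one scaled input window per (timestamp, ensemble member) pair.
+
+    The last 48 rows of each window are overwritten with the external forecast
+    values for that timestamp and member; targets are a dummy tensor.
+    """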
+    def __init__(self, df_scaled, in_size, ext_meta, df_ext=None) -> None:
+        self.df_scaled = df_scaled
+        self.in_size = in_size
+        self.df_ext = df_ext
+        self.ext_meta = ext_meta  # TODO rename
+        self.y = torch.tensor([0.0])  # dummy target; real targets are unknown at inference time
+
+        if df_ext is None:
+            self.max_ens_size = 1
+            self.common_indexes = df_scaled.index[self.in_size:] #TODO check: [:tstamp] slicing is inclusive, so index[self.in_size - 1:] may be the correct start
+        else:
+            self.max_ens_size = max(i_meta.ensemble_members for i_meta in ext_meta)
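+            # usable timestamps are those for which every external forecast source has data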
+            self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta]
+            self.common_indexes = reduce(np.intersect1d, self.unique_indexes)
+
+    def __len__(self) -> int:
+        return len(self.common_indexes)*self.max_ens_size
+
+    def get_whole_index(self):
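+        """Returns a DataFrame with one (tstamp, member) row per item; member is -1 without external forecasts."""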
+        if self.df_ext is None:
+            rows = [(self.common_indexes[idx // self.max_ens_size], -1) for idx in range(len(self))]
+        else:
+            rows = [(self.common_indexes[idx // self.max_ens_size], idx % self.max_ens_size) for idx in range(len(self))]
+        return pd.DataFrame(rows, columns=["tstamp", "member"])
+
+    def __getitem__(self, idx):
+        if idx >= len(self):
+            raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}")
+        
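+        # each timestamp occupies max_ens_size consecutive indices: idx -> (tstamp, member)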
+        tstamp = self.common_indexes[idx // self.max_ens_size]
+        ens_num = idx % self.max_ens_size
+
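+        # [:tstamp] label slicing is inclusive, so the in_size-row window ends at tstamp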
+        x = self.df_scaled[:tstamp][-self.in_size:].astype("float32")
+
+        if self.df_ext is not None:
+            for meta in self.ext_meta:
+                if meta.ensemble_members == 1:
+                    ext_values = self.df_ext[self.df_ext["sensor_name"] == meta.sensor_name].loc[tstamp].iloc[2:].values
+                else:
+                    #TODO this crashes when a timestamp has only ensemble runs or only the main run
+                    ext_values = self.df_ext[(self.df_ext["sensor_name"] == meta.sensor_name) & (self.df_ext["member"] == ens_num)].loc[tstamp].iloc[2:].values
+
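+                # overwrite the last 48 hours of this input column with the external forecast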
+                col_idx = x.columns.get_loc(meta.sensor_name)
+                x.iloc[-48:, col_idx] = ext_values.astype("float32")
+
+        # targets are unknown at inference time, so return the dummy y
+        return x.values, self.y
 
 class OracleWaVoConnection:
     """Class for writing/reading from the Wavo Database."""
@@ -87,12 +141,29 @@ class OracleWaVoConnection:
 
                 df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
                     
-            dataset = self._make_dataset(df_base_input, model,input_meta_data,df_temp)
-            #else:
-
+            dataset = self._make_dataset(gauge_config, df_base_input, model, input_meta_data, df_temp)
+            df_pred = pred_multiple_db(model, dataset)
+            df_pred["created"] = created
+            df_pred["sensor_name"] = gauge_config["gauge"]
+            df_pred["model"] = gauge_config["model_folder"].name
+
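+            # look up which (tstamp, member) forecasts already exist for this model and gauge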
+            stmt = select(PegelForecasts.tstamp, PegelForecasts.member).where(
+                PegelForecasts.model == gauge_config["model_folder"].name,
+                PegelForecasts.sensor_name == gauge_config["gauge"],
+                PegelForecasts.tstamp.in_(df_pred["tstamp"]),
+            )
+
+            df_already_inserted = pd.read_sql(sql=stmt, con=self.engine)
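+            # indicator marks rows as 'both' (already in the DB) or 'left_only' (new)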
+            merged = df_pred.merge(df_already_inserted[['tstamp', 'member']], on=['tstamp', 'member'], how='left', indicator=True)
+            df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+            df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1)
 
-            pred_multiple_db(model, df_base_input)
-            print("a")
+            with Session(self.engine) as session:
+                if not df_pred_create.empty:
+                    forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist()
+                    session.add_all(forecast_objs)
+                if not df_pred_update.empty:
+                    session.execute(update(PegelForecasts), df_pred_update.to_dict(orient='records'))
+                session.commit()
     
         else:
             #TODO this is weird, cause it's just one at the moment
@@ -528,34 +599,37 @@ class OracleWaVoConnection:
 
     def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name, created):
         if self.main_config.get("export_zrxp") and member == 0:
-            with Session(self.engine) as session:
-                stmt = select(Modell).where(Modell.modelldatei == str(gauge_config["model_folder"].resolve())) 
-                model_obj = session.scalar(stmt)
 
-            if model_obj is None:
+            try:
+                df_models = pd.read_sql(sql=select(Modell.modellname, Modell.modelldatei), con=self.engine)
+                df_models["modelldatei"] = df_models["modelldatei"].apply(lambda p: Path(p).resolve())
+                real_model_name = df_models[df_models["modelldatei"] == gauge_config["model_folder"].resolve()]["modellname"].values[0]
+            except IndexError:
                 logging.error("Model %s not found in database, skipping zrxp export",gauge_config["model_folder"])
                 return
+            else:  # not "finally": a finally block would still run the export after the early return above
 
-            target_file =self.main_config["zrxp_out_folder"] / f"{end_time.strftime('%Y%m%d%H')}_{sensor_name.replace(',','_')}_{model_name}.zrx"
-            #2023 11 19 03
-            df_zrxp = pd.DataFrame(forecast,columns=["value"])
-            df_zrxp["timestamp"] = end_time
-            df_zrxp["forecast"] = pd.date_range(start=end_time,periods=49,freq="1h")[1:]
-            df_zrxp["member"] = member        
-            df_zrxp = df_zrxp[['timestamp','forecast','member','value']]
+                target_file = self.main_config["zrxp_out_folder"] / f"{end_time.strftime('%Y%m%d%H')}_{sensor_name.replace(',','_')}_{model_name}.zrx"
+                # filename timestamp, e.g. 2023111903 for 2023-11-19 03:00
+                df_zrxp = pd.DataFrame(forecast,columns=["value"])
+                df_zrxp["timestamp"] = end_time
+                df_zrxp["forecast"] = pd.date_range(start=end_time,periods=49,freq="1h")[1:]
+                df_zrxp["member"] = member        
+                df_zrxp = df_zrxp[['timestamp','forecast','member','value']]
 
 
-            with open(target_file , 'w', encoding="utf-8") as file:
-                file.write('#REXCHANGEWISKI.' + model_obj.modellname + '.W.KNN|*|\n')
-                file.write('#RINVAL-777|*|\n')
-                file.write('#LAYOUT(timestamp,forecast, member,value)|*|\n')
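+                # ZRXP header lines: exchange key, invalid-value marker, column layout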
+                with open(target_file, 'w', encoding="utf-8") as file:
+                    file.write('#REXCHANGEWISKI.' + real_model_name + '.W.KNN|*|\n')
+                    file.write('#RINVAL-777|*|\n')
+                    file.write('#LAYOUT(timestamp,forecast, member,value)|*|\n')
 
-            df_zrxp.to_csv(path_or_buf = target_file,
-                        header = False,
-                        index=False,
-                        mode='a',
-                        sep = ' ',
-                        date_format = '%Y%m%d%H%M')
+                df_zrxp.to_csv(path_or_buf=target_file,
+                               header=False,
+                               index=False,
+                               mode='a',
+                               sep=' ',
+                               date_format='%Y%m%d%H%M')
 
 
     def _ext_fcst_metadata(self, gauge_config):
@@ -568,61 +642,32 @@ class OracleWaVoConnection:
 
 
 
-    def _make_dataset(self, df_input, model,input_meta_data,df_ext_forecasts):
+    def _make_dataset(self, gauge_config, df_input, model, input_meta_data, df_ext):
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=UserWarning)
             if model.differencing == 1:
                 df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
             #TODO time embedding stuff
             df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns)
-        #x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
         #TODO skalierung df_ext_forecast
+        trans_cols = [f"h{i}" for i in range(1,49)]
+        for col in gauge_config["external_fcst"]:
+            idx = gauge_config["columns"].index(col)
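+            # bind this column's mean/std; applied immediately below, so the loop late-binding pitfall does not apply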
+            def scale_curry(x):
+                return scale_standard(x, model.scaler.mean_[idx], model.scaler.scale_[idx])  # TODO maybe people want to use different scalers
 
-        class DBDataSet(Dataset):
-            def __init__(self, df_scaled,in_size,ext_meta,df_ext) -> None:
-                self.df_scaled = df_scaled
-                self.in_size = in_size
-                self.df_ext = df_ext
-                self.ext_meta = ext_meta #TODO rename
+            df_ext.loc[df_ext["sensor_name"]==col,trans_cols] = df_ext.loc[df_ext["sensor_name"]==col,trans_cols].apply(scale_curry)
 
-                self.max_ens_size = max([i_meta.ensemble_members for i_meta in ext_meta])
-                
 
-                self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta]
-                self.common_indexes = reduce(np.intersect1d, self.unique_indexes)
-
-            def __len__(self) -> int:
-                return len(self.common_indexes)*self.max_ens_size
-
-            def get_whole_index(self):
-                rows = [(self.common_indexes[idx // self.max_ens_size],idx % self.max_ens_size) for idx in range(len(self))]
-                                   
-                return pd.DataFrame(rows,columns=["tstamp","member"])
-                    
-                
-            def __getitem__(self, idx):
-                tstamp = self.common_indexes[idx // self.max_ens_size]
-                ens_num = idx % self.max_ens_size
+            print("a")
+            #scale_standard
 
-                x = self.df_scaled[:tstamp][-self.in_size:]
 
-                for meta in self.ext_meta:
-                    print(meta)
-                    if meta.ensemble_members == 1:
-                        #ext_values = self.df_ext_forecasts.loc[tstamp,meta_data.sensor_name]
-                        ext_values  = self.df_ext[self.df_ext["sensor_name"]== meta.sensor_name].loc[tstamp].iloc[2:].values
-                    else:
-                        ext_values = self.df_ext[(self.df_ext["sensor_name"]== meta.sensor_name) & (self.df_ext["member"]== ens_num)].loc[tstamp].iloc[2:].values
+        ds = DBDataSet(df_scaled, model.in_size, input_meta_data, df_ext)
+        return ds
 
-                    
-                    col_idx = x.columns.get_loc(meta.sensor_name)
-                    x.iloc[-48:,col_idx] = ext_values.astype("float32")
-
-                if idx >= len(self):
-                    raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}")
-                #y = self.y[idx+self.in_size:idx+self.in_size+self.out_size]
-                return x.values, None #TODO auch y ? 
-    
         
     def maybe_update_tables(self) -> None:
         """
@@ -806,7 +851,7 @@ def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
 
     
 @torch.no_grad()
-def pred_multiple_db(model, df_input: pd.DataFrame) -> torch.Tensor:
+def pred_multiple_db(model, ds: Dataset) -> pd.DataFrame:
     """
-    Predicts the output for a single input data point.
+    Predicts one forecast for every (timestamp, ensemble member) pair in the dataset.
 
@@ -817,23 +862,13 @@ def pred_multiple_db(model, df_input: pd.DataFrame) -> torch.Tensor:
     Returns:
-        torch.Tensor: Predicted output as a tensor.
+        pd.DataFrame: Forecasts with columns tstamp, member and h1..h48.
     """
-    with warnings.catch_warnings():
-        warnings.filterwarnings(action="ignore", category=UserWarning)
-        if model.differencing == 1:
-            df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
-        #TODO time embedding stuff
-        x = model.scaler.transform(
-            df_input.values
-        )  # TODO values/df column names wrong/old
-    x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
 
-    # y = ut.inv_standard(model(x), model.scaler.mean_[model.target_idx], model.scaler.scale_[model.target_idx])
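+    # hp.get_pred is expected to return one 48-step forecast per dataset item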
+    data_loader = DataLoader(ds, batch_size=256, num_workers=1)
+    pred = hp.get_pred(model, data_loader, None)
+    df_pred = ds.get_whole_index()
+    df_pred = pd.concat([df_pred, pd.DataFrame(pred, columns=[f"h{i}" for i in range(1, 49)])], axis=1)
 
-    data_loader = DataLoader(dataset, batch_size=256,num_workers=8)
-    hp.get_pred(model, x,None)
-    y = model.predict_step((x, None), 0)
-
-    return y[0]
+    return df_pred
 
 
 def get_merge_forecast_query(out_size: int = 48) -> str:
diff --git a/src/utils/utility.py b/src/utils/utility.py
index 71e2ea1123674cbde055bdeb71b42f276c94dd21..d1039da2f7795545152aec752f43029f35cc8912 100644
--- a/src/utils/utility.py
+++ b/src/utils/utility.py
@@ -91,6 +91,24 @@ def inv_minmax(x, xmin: float, xmax: float):
     return x * (xmax-xmin) + xmin
 
 
+def scale_standard(x, mean: float, std: float):
+    """Standard scaling, also unpacks torch tensors if needed
+
+    Args:
+        x (np.array, torch.Tensor or similar): array to scale
+        mean (float): mean value of a (training) dataset
+        std (float): standard deviation value of a (training) dataset
+
+    Returns:
+        (np.array, torch.Tensor or similar): standard scaled values
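+
+    Example:
+        >>> scale_standard(np.array([1.0, 2.0, 3.0]), mean=2.0, std=1.0)
+        array([-1.,  0.,  1.])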
+    """
+    if isinstance(x, torch.Tensor) and isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor):
+        return (x - mean) / std
+    if isinstance(mean, torch.Tensor):
+        mean = mean.item()
+    if isinstance(std, torch.Tensor):
+        std = std.item()
+    return (x - mean) / std
 
 def inv_standard(x, mean: float, std: float):
     """Inverse standard scaling, also unpacks torch tensors if needed