diff --git a/src/models/ensemble_models.py b/src/models/ensemble_models.py
index b245c971932b9c7985e6fd6afb5b24597ad7ce0d..f83c6845b7b80c8044732a8e017ae90213d85122 100644
--- a/src/models/ensemble_models.py
+++ b/src/models/ensemble_models.py
@@ -128,7 +128,7 @@ class WaVoLightningEnsemble(pl.LightningModule):
     """ Since the data for this is normalized using the scaler from the first model in model_list all models should be trained on the same or at least similar data i think?
     """
     # pylint: disable-next=unused-argument
-    def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate):
+    def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate,return_base=False):
         super().__init__()
 
         self.model_list = model_list
@@ -139,7 +139,7 @@ class WaVoLightningEnsemble(pl.LightningModule):
         self.scaler = model_list[0].scaler #TODO load the scaler properly from the yaml -> pass it as a parameter, like for the normal model.
         assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx"
         self.model_architecture = 'ensemble'
-
+        self.return_base = return_base
 
         #self.weighting = WeightingModelAttention(
         self.weighting = WeightingModel(
@@ -154,11 +154,14 @@ class WaVoLightningEnsemble(pl.LightningModule):
 
         #TODO different insizes
         #TODO ModuleList?!
-        self.save_hyperparameters(ignore=['model_list','scaler'])
+        self.save_hyperparameters(ignore=['model_list','scaler','return_base'])
         self.save_hyperparameters({"scaler": pickle.dumps(self.scaler)})
         for model in self.model_list:
             model.freeze()
 
+    def set_return_base(self,return_base):
+        """If True, the forward pass also returns the unweighted per-model predictions."""
+        self.return_base = return_base
+
     def _common_step(self, batch, batch_idx,dataloader_idx=0):
         """
         Computes the weighted average of the predictions of the models in the ensemble.
@@ -184,7 +187,8 @@ class WaVoLightningEnsemble(pl.LightningModule):
         y_hats = torch.stack(y_hat_list,axis=1)
         w = self.weighting(x)
         y_hat = torch.sum((y_hats*w),axis=1)
-
+        if self.return_base:
+            return y_hat, y_hats
         return y_hat
 
 
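A minimal sketch (not part of this patch; the prepared batch x and calling the module directly are assumptions) of how the new return_base flag changes the ensemble output:

    ensemble.set_return_base(True)
    y_hat, y_hats = ensemble(x)  # y_hat: (batch, horizon), y_hats: (batch, n_models, horizon)
    # y_hat is the weighting-network-weighted sum over the model axis, i.e.
    # torch.sum(y_hats * ensemble.weighting(x), axis=1)
    ensemble.set_return_base(False)
    y_hat = ensemble(x)  # plain weighted average again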
diff --git a/src/utils/db_tools/db_tools.py b/src/utils/db_tools/db_tools.py
index a44a132ad4478acc44941b3eeb37844262684578..9dfddb7fd753cf65ba12b35109fd8bfbf6eac645 100644
--- a/src/utils/db_tools/db_tools.py
+++ b/src/utils/db_tools/db_tools.py
@@ -129,19 +129,11 @@ class OracleWaVoConnection:
         input_meta_data = self._ext_fcst_metadata(gauge_config)
 
         if len(self.times) > 1:
-            df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created)
+            df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created,series=True)
+            df_ext = self.load_input_ext_fcst(gauge_config["external_fcst"]) #load external forecasts
 
-            #load external forecasts
-            if self.main_config.get("fake_external"):
-                df_temp = None
-            else:
-                stmst = select(InputForecasts).where(InputForecasts.sensor_name.in_(gauge_config["external_fcst"]),
-                                                    between(InputForecasts.tstamp, self.times[0], self.times[-1]),
-                                                    InputForecasts.member.in_(self.members))
-
-                df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
                     
-            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_temp)
+            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_ext,model.differencing)
             df_pred = pred_multiple_db(model, dataset)
             df_pred["created"] = created
             df_pred["sensor_name"] = gauge_config["gauge"]
@@ -151,7 +143,6 @@ class OracleWaVoConnection:
                 PegelForecasts.model==gauge_config["model_folder"].name,
                 PegelForecasts.sensor_name==gauge_config["gauge"],
                 PegelForecasts.tstamp.in_(df_pred["tstamp"]),
-                #between(PegelForecasts.tstamp, df_pred["tstamp"].min(), df_pred["tstamp"].max()),
                 )
                                                                                                                             
             df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
@@ -231,37 +222,12 @@ class OracleWaVoConnection:
             None
         """
         try:
-            if member == -1:
-                df_input = df_base_input.fillna(0)
-            else:
-                # replace fake forecast with external forecast
-                df_input = df_base_input.copy()
-                for input_meta in input_meta_data:
-                    if input_meta.ensemble_members == 1:
-                        temp_member = 0
-                    elif input_meta.ensemble_members == 21:
-                        temp_member = member
-                    else:
-                        temp_member = member #TODO was hier tun?
-
-                    fcst_data = df_temp[(df_temp["member"]== temp_member) & (df_temp["sensor_name"] == input_meta.sensor_name)]
-                    if len(fcst_data) == 0:
-                        raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
-
-                    # Replace missing values with forecasts.
-                    # It is guaranteed that the only missing values are in the "future" (after end_time)
-                    assert df_input[:-48].isna().sum().sum() == 0
-                    nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
-                    nan_indices = nan_indices[nan_indices >= 0]
-                    nan_indices2 = nan_indices + (144 - 48)
-                    col_idx = df_input.columns.get_loc(input_meta.sensor_name)
-                    df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
-
+            df_input = self.fill_base_input(member,df_base_input,input_meta_data,df_temp)
         except LookupError as e:
             logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]})
             y = torch.Tensor(48).fill_(np.nan)
         else:
-            y = pred_single_db(model, df_input)
+            y = pred_single_db(model, df_input,differencing=model.differencing)
         finally:
             self.insert_forecast(
                 y,gauge_config,end_time, member, created
@@ -274,7 +240,10 @@ class OracleWaVoConnection:
         Args:
             ensemble_config (dict): The configuration for the ensemble, model_folder, ensemble_name, ensemble_model.
         """
+        #TODO add logging here
+        #TODO zrxp exports?
         created = datetime.now(timezone.utc)
+        # Collect the base models
         stmt = select(Modell).where(Modell.modellname.in_(bindparam("model_names",expanding=True)))
         df_models = pd.read_sql(sql=stmt,con=self.engine,params={"model_names": ensemble_config["model_names"]})
         
@@ -285,54 +254,113 @@ class OracleWaVoConnection:
             model_list.append(model_path)
             pseudo_model_names.append(model_path.name)
 
-        stmt = select(PegelForecasts).where(
-            PegelForecasts.tstamp.in_(bindparam("times",expanding=True)),
-            PegelForecasts.model.in_(bindparam("pseudo_model_names", expanding=True)),
-        )
-        #self.times
-        #            SensorData.sensor_name.in_(bindparam("model_columns", expanding=True)),
-        df_ens = pd.read_sql(
-            sql=stmt,
-            con=self.engine,
-            index_col="tstamp",
-            params={
-                "times": self.times,
-                "pseudo_model_names": pseudo_model_names,
-            },
-        )
-        fake_config = {'gauge': df_ens.iloc[0,0],
-                       'model_folder' : ensemble_config['ensemble_name']}
+        fake_config = self._generate_fake_config(df_models,ensemble_config)
+        input_meta_data = self._ext_fcst_metadata(fake_config)
+        df_ext = self.load_input_ext_fcst(fake_config["external_fcst"]) #load external forecasts
+
+        #load ensemble model, maybe with dummies
+        if ensemble_config["mode"] != 'naiv':
+            model, ensemble_yaml = hp.load_ensemble_model(ensemble_config["ensemble_model"],model_list,dummy=ensemble_config["mode"] == 'only_ensemble')
+            model.set_return_base(ensemble_config["mode"] == 'save_both')
+            differencing = ensemble_yaml["differencing"]
+
         
-        if ensemble_config["ensemble_model"] is not None:
-            base_model = hp.load_model(model_list[0])
+        #save_both -> calculate the base models and the ensemble model and save both
+        #calc_both -> calculate the base models and the ensemble model but only save the ensemble model
+
+        #Warning: This behaves slightly differently from the other versions.
+        #Here the external forecasts only overwrite their corresponding sensor values if they are nan.
+        #This means the model has more information for older timestamps.
+        if ensemble_config["mode"] in ('only_ensemble', 'naiv'):
+            stmt = select(PegelForecasts).where(
+                PegelForecasts.tstamp.in_(bindparam("times",expanding=True)),
+                PegelForecasts.model.in_(bindparam("pseudo_model_names", expanding=True)))
+            df_ens = pd.read_sql(sql=stmt,con=self.engine,index_col="tstamp",
+                params={"times": self.times,"pseudo_model_names": pseudo_model_names})
+
+            #self.insert_forecast(y,fake_config,cur_time, member, created)
+            #TODO this could be much more efficient if the forecasts were collected first
+            if ensemble_config["mode"] == 'naiv': # -> load the results of the base models from the database and take the naive mean
+                for cur_time in df_ens.index.unique():
+                    df_y_time_slice = df_ens.loc[cur_time]
+                    for member in self.members:
+                        df_y_member = df_y_time_slice[df_y_time_slice["member"] == member]
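+                        # columns 4 onward hold the h1..h48 forecast horizons; the naive
+                        # ensemble is their unweighted mean across the base models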
+                        y = df_y_member.iloc[:,4:].mean()
+                        y = torch.Tensor(y)
+                        self.insert_forecast(y,fake_config,cur_time, member, created)
+
+            if ensemble_config["mode"] == 'only_ensemble': # -> load the results of the base models from the database
+                for cur_time in df_ens.index.unique():
+                    df_y_time_slice = df_ens.loc[cur_time]
+                    df_base_input = self.load_input_db(model.max_in_size,cur_time,fake_config,created)
+                    for member in self.members:
+                        df_y_member = df_y_time_slice[df_y_time_slice["member"] == member]
+                        y_ens = df_y_member.iloc[:,4:].values
+                        try:
+                            x = self._get_x(df_base_input,member,input_meta_data,df_ext,differencing,model)
+                        except LookupError as e:
+                            logging.error(e.args[0],extra={"gauge":fake_config["gauge"]})
+                            y = torch.Tensor(48).fill_(np.nan)
+                        else:
+                            w = model.weighting(x)
+                            y = torch.sum((torch.tensor(y_ens)*w),axis=1)[0]
+                        finally:
+                            self.insert_forecast(y,fake_config,cur_time, member, created)
+            
+        elif ensemble_config["mode"] in ('save_both', 'calc_both'):
+            #df_base_input_all = self.load_input_db(model.max_in_size,self.times[-1],fake_config,created)
+
+            df_base_input = self.load_input_db(model.max_in_size, self.times[-1], fake_config,created,series=True)
+            ds = self._make_dataset(fake_config,df_base_input, model,input_meta_data,df_ext,ensemble_yaml["differencing"])
+            #df_pred = pred_multiple_db(model, dataset)
+            data_loader = DataLoader(ds, batch_size=256,num_workers=1)
+            pred = hp.get_pred(model, data_loader,None)
+
+            df_base_preds = []                
+            df_template = ds.get_whole_index()
+            if ensemble_config["mode"] == 'save_both':
+                pred, base_preds = pred
+                for j in range(base_preds.shape[1]):
+                    df_temp = pd.concat([df_template, pd.DataFrame(base_preds[:,j,:],columns=[f"h{i}" for i in range(1,49)])], axis=1)
+                    df_temp["model"] = model_list[j].name
+                    df_base_preds.append(df_temp)
+                    
+            df_temp = pd.concat([df_template, pd.DataFrame(pred,columns=[f"h{i}" for i in range(1,49)])], axis=1)
+            df_temp["model"] = fake_config["model_folder"].name
+            df_base_preds.append(df_temp)
 
+            df_pred = pd.concat(df_base_preds)
+            df_pred["created"] = created
+            df_pred["sensor_name"] = fake_config["gauge"]
 
-            ensemble_folder = Path(ensemble_config["ensemble_model"])
-            ensemble_yaml = hp.load_settings_model(ensemble_folder)
 
-            from models.ensemble_models import WaVoLightningEnsemble
+            stmt =select(PegelForecasts.tstamp,PegelForecasts.member,PegelForecasts.model).where(
+                #PegelForecasts.model==fake_config["model_folder"].name,
+                #PegelForecasts.model.in_(pseudo_model_names + [fake_config["model_folder"].name]),
+                PegelForecasts.model.in_(df_pred["model"].unique()),
+                PegelForecasts.sensor_name==fake_config["gauge"],
+                PegelForecasts.tstamp.in_(df_pred["tstamp"].unique()),
+                )
+                                                                                                                            
+            df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
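+            # anti-join via merge(indicator=True): rows only in df_pred are new and get
+            # INSERTed, rows already present in the database get bulk UPDATEd below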
+            merged = df_pred.merge(df_already_inserted[['tstamp', 'member','model']],on=['tstamp', 'member','model'], how='left',indicator=True)
+            df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+            df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1)
+            # updating all rows at once raises sqlalchemy.orm.exc.StaleDataError when some rows
+            # don't exist yet ("expected to update 13860 row(s); 12600 were matched"), hence the split
+            with Session(self.engine) as session:
+                forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist()
+                session.add_all(forecast_objs)
+                session.execute(update(PegelForecasts),df_pred_update.to_dict(orient='records'))
+                session.commit()
 
-            checkpoint_path = next((ensemble_folder / "checkpoints").iterdir())
-            temp_model_list = len(ensemble_yaml["model_path_list"])*[base_model]
-            model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=temp_model_list)
-            #model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=None)
 
 
 
-        for cur_time in df_ens.index.unique():
-            df_time_slice = df_ens.loc[cur_time]
-            for member in self.members:
-                if ensemble_config["ensemble_model"] is None:
-                    df_member = df_time_slice[df_time_slice["member"] == member]
-                    y = df_member.iloc[:,4:].mean()
-                    y = torch.Tensor(y)
-                else:
-                    y = None
-        
-                self.insert_forecast(y,fake_config,cur_time, member, created)
 
 
-    def load_input_db(self, in_size, end_time, gauge_config,created) -> pd.DataFrame:
+
+    def load_input_db(self, in_size, end_time, gauge_config,created,series=False) -> pd.DataFrame:
         """
         Loads input data from the database based on the current configuration.
 
@@ -341,6 +369,7 @@ class OracleWaVoConnection:
             end_time (pd.Timestamp): The end time of the input data.
             gauge_config (dict): The configuration for the gauge, important here are the columns and external forecasts.
             created (pd.Timestamp): The timestamp of the start of the forecast creation.
+            series (bool): If True, the input data for all timestamps is loaded at once. This doesn't work with the standard path because forecasts would be available to models that shouldn't have them yet.
 
         Returns:
             pd.DataFrame: The loaded input data as a pandas DataFrame.
@@ -349,7 +378,7 @@ class OracleWaVoConnection:
         columns=gauge_config["columns"]
         external_fcst=gauge_config["external_fcst"]
         
-        if len(self.times) > 1:
+        if len(self.times) > 1 and series:
             first_end_time = self.times[0]
             start_time = first_end_time - pd.Timedelta(in_size - 1, "hours")
             if self.main_config.get("fake_external"):
@@ -452,6 +481,62 @@ class OracleWaVoConnection:
 
         return df_input
 
+    def load_input_ext_fcst(self, external_fcst=None) -> pd.DataFrame:
+        """Loads the external input forecasts for the given sensors and the current times/members.
+
+        Args:
+            external_fcst (list, optional): Sensor names of the external forecasts. Defaults to None.
+
+        Returns:
+            pd.DataFrame: The external forecast data, or None if fake_external is set.
+        """
+        if self.main_config.get("fake_external"):
+            return None
+
+        if external_fcst is None:
+            external_fcst = []
+
+        stmt = select(InputForecasts).where(InputForecasts.sensor_name.in_(external_fcst),
+                                            between(InputForecasts.tstamp, self.times[0], self.times[-1]),
+                                            InputForecasts.member.in_(self.members))
+
+        df_ext = pd.read_sql(sql=stmt,con=self.engine,index_col="tstamp")
+        return df_ext
+
+    def fill_base_input(self,member,df_base_input,input_meta_data,df_ext) -> pd.DataFrame:
+        """Fills the base input data with external forecasts depending on the current (available) member.
+
+        Args:
+            member (int): Member identifier (-1 to 21).
+            df_base_input (pd.DataFrame): Sensor data from the database.
+            input_meta_data (list): Metadata for the external forecasts.
+            df_ext (pd.DataFrame): External forecast data.
+
+        Returns:
+            pd.DataFrame: Input data with external forecasts inserted.
+        """
+        if member == -1:
+            df_input = df_base_input.fillna(0)
+        else:
+            # replace fake forecast with external forecast
+            df_input = df_base_input.copy()
+            for input_meta in input_meta_data:
+                if input_meta.ensemble_members == 1:
+                    temp_member = 0
+                elif input_meta.ensemble_members == 21:
+                    temp_member = member
+                else:
+                    temp_member = member #TODO what to do here?
+
+                fcst_data = df_ext[(df_ext["member"]== temp_member) & (df_ext["sensor_name"] == input_meta.sensor_name)]
+                if len(fcst_data) == 0:
+                    raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
+
+                # Replace missing values with forecasts.
+                # It is guaranteed that the only missing values are in the "future" (after end_time)
+                assert df_input[:-48].isna().sum().sum() == 0
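+                # Worked example (window of 144 rows, 48-step horizon, as hard-coded below):
+                # if rows 96..143 of the sensor column are NaN, nan_indices becomes 0..47,
+                # i.e. the position inside the 48-value external forecast, and nan_indices2
+                # maps each forecast step back to its row in df_input.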
+                nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
+                nan_indices = nan_indices[nan_indices >= 0]
+                nan_indices2 = nan_indices + (144 - 48)
+                col_idx = df_input.columns.get_loc(input_meta.sensor_name)
+                df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
+                
+        return df_input
+
 
     def add_zrxp_data(self, zrxp_folders: List[Path]) -> None:
         """Adds zrxp data to the database
@@ -660,6 +745,32 @@ class OracleWaVoConnection:
             #try:
             session.commit()
 
+    def _generate_fake_config(self,df_models,ensemble_config):
+        """Builds a gauge-style config dict for an ensemble from its first base model."""
+        model_id = df_models.iloc[0,0].item()
+
+        query = (
+            select(ModellSensor)
+            .where(ModellSensor.modell_id == model_id)
+            .order_by(ModellSensor.ix)
+        )
+        df_ms = pd.read_sql(query,self.engine)
+
+        fake_config = {'gauge': df_models["target_sensor_name"].values[0],
+                    'model_folder' : ensemble_config['ensemble_name'],
+                    'external_fcst' : list(df_ms.dropna()["sensor_name"]), #drop the rows without a value in vhs_gebiet
+                    'columns' : list(df_ms["sensor_name"])}
+        return fake_config
+    
+    def _get_x(self,df_base_input,member,input_meta_data,df_ext,differencing,model):
+        """Builds the scaled input tensor for a single member, shaped (1, window, features)."""
+        df_input = self.fill_base_input(member,df_base_input,input_meta_data,df_ext)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(action="ignore", category=UserWarning)
+            if differencing == 1:
+                df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
+            x = model.scaler.transform(df_input.values)
+        x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
+        return x
+
 
     def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name):
         if self.main_config.get("export_zrxp") and member == 0:
@@ -704,10 +815,10 @@ class OracleWaVoConnection:
 
 
 
-    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext):
+    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext,differencing):
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=UserWarning)
-            if model.differencing == 1:
+            if differencing == 1:
                 df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
             #TODO time embedding stuff
             df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns)
@@ -876,20 +987,21 @@ class OracleWaVoConnection:
 
 
 @torch.no_grad()
-def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
+def pred_single_db(model, df_input: pd.DataFrame,differencing=True) -> torch.Tensor:
     """
     Predicts the output for a single input data point.
 
     Args:
         model (LightningModule): The model to use for prediction.
         df_input (pandas.DataFrame): Input data point as a DataFrame.
+        differencing (bool): If True, a first-difference column of the target value is appended as the last column.
 
     Returns:
         torch.Tensor: Predicted output as a tensor.
     """
     with warnings.catch_warnings():
         warnings.filterwarnings(action="ignore", category=UserWarning)
-        if model.differencing == 1:
+        if differencing == 1:
             df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
         #TODO time embedding stuff
         x = model.scaler.transform(
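
A small illustration (with made-up data) of the d1 feature that pred_single_db and _make_dataset append when differencing == 1; it is simply the first difference of the target column, with the leading NaN replaced by 0:

    import pandas as pd

    df = pd.DataFrame({"level": [1.0, 2.0, 4.0, 4.5]})
    df["d1"] = df.iloc[:, 0].diff().fillna(0)
    print(df["d1"].tolist())  # [0.0, 1.0, 2.0, 0.5]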
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index ce591b6212f15526f1ab9b144d23afc0bbd349d9..23c293da67985b78a4078cbf287cf7e92d1bd3d1 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -16,6 +16,7 @@ from torch.utils.data import DataLoader
 # from data_tools.data_module import WaVoDataModule
 from data_tools.data_module import WaVoDataModule
 from models.lightning_module import WaVoLightningModule
+from models.ensemble_models import WaVoLightningEnsemble
 
 
 def load_settings_model(model_dir: Path) -> dict:
@@ -90,11 +91,17 @@ def get_pred(
     if move_model:
         model.cuda()
 
-    y_pred = np.concatenate(pred)
+    if getattr(model, 'return_base', False):
+        pred, base_preds = pred[0] #TODO this function is not tested with more than one batch
+        return pred, base_preds
 
-    if y_true is not None:
-        y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred?
-        y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
+    else:
+        y_pred = np.concatenate(pred)
+
+        if y_true is not None:
+            y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred?
+            y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
     return y_pred
 
 
@@ -211,3 +218,36 @@ def load_modules_rl(
     data_module.setup(stage="")
 
     return model_list, data_module
+
+
+def load_ensemble_model(model_dir, model_dir_list,dummy=False):
+    """Loads a WaVoLightningEnsemble from the given folder; also needs a list of the models that are part of the ensemble.
+
+    Args:
+        model_dir : Path to the folder containing the ensemble model
+        model_dir_list : List of paths to the models that are part of the ensemble
+        dummy (bool, optional): If True, only the first model in model_dir_list is actually loaded (and reused for all slots). Defaults to False.
+
+    Returns:
+        tuple: The loaded ensemble model and the ensemble_yaml dict with its settings.
+    """
+    model_dir = Path(model_dir)
+    ensemble_yaml = load_settings_model(model_dir)
+
+    if dummy:
+        base_model = load_model(model_dir=model_dir_list[0])
+        model_list = len(model_dir_list)*[base_model]
+    else:
+        model_list = [load_model(model_dir=Path(p)) for p in model_dir_list]
+
+    for model in model_list:
+        model.eval()
+
+    checkpoint_path = next((model_dir / "checkpoints").iterdir())
+    model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=model_list)
+
+    model.gauge_idx = model_list[0].gauge_idx #hacky
+    model.in_size = model.max_in_size
+    model.eval()
+    return model, ensemble_yaml
+
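
A hedged usage sketch of the new helper (paths and the hp import alias are placeholders, mirroring how db_tools calls it):

    from pathlib import Path
    from utils import helpers as hp

    model, ensemble_yaml = hp.load_ensemble_model(
        Path("models/ensemble_run"),
        [Path("models/base_a"), Path("models/base_b")],
        dummy=False,  # load every base model instead of duplicating the first one
    )
    model.set_return_base(True)  # forward now returns (y_hat, per-model y_hats)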