diff --git a/configs/db_new.yaml b/configs/db_new.yaml
index dd07f1acb2c833fe4b7d033e2a46459c8d13e822..4390bb03487bb788e13d738b7b6e3c08f603331d 100644
--- a/configs/db_new.yaml
+++ b/configs/db_new.yaml
@@ -13,7 +13,6 @@ main_config:
   export_zrxp : True
   #if True tries to put  all files in folder in the external forecast table
   fake_external : False
-  load_sensor : False
   load_zrxp : True
   ensemble : True
   single : True
@@ -21,55 +20,274 @@ main_config:
   #start: 2019-12-01 02:00
   #start: 2023-04-19 03:00
   #start: 2023-11-19 03:00
-  start: 2024-09-13 09:00
-  end: !!str 2024-09-14 09:00
+  start: 2024-09-13 06:00
+  end: !!str 2024-09-13 21:00
   #end: !!str 2021-02-06
   # if range is true will try all values from start to end for predictions
   range: !!bool True
   #range: !!bool False
-  zrxp: !!bool True
 gauge_configs:
-  - gauge: 112211,S,5m.Cmd
-    model_folder: ../models_torch/hollingstedt_version_37
+  - gauge: 114135,S,60m.Cmd
+    model_folder: ../models_torch/db_testing/lightning_logs/version_11
     columns:
-      - 4466,SHum,vwsl 
-      - 4466,SHum,bfwls 
-      - 4466,AT,h.Cmd-2 
-      - 114069,S,15m.Cmd 
-      - 111111,S,5m.Cmd 
-      - 9520081,S_Tide,1m.Cmd 
-      - 9530010,S_Tide,1m.Cmd 
-      - 112211,Precip,h.Cmd 
-      - 112211,S,5m.Cmd
+      - 7427,SHum,vwsl
+      - 7427,SHum,bfwls
+      - 7427,AT,h.Mean
+      - 114200,S,60m.Cmd
+      - 114131,S,60m.Cmd
+      - 114135,Precip,h.Cmd
+      - 114135,S,60m.Cmd
     external_fcst:
-      - 9530010,S_Tide,1m.Cmd
-      - 112211,Precip,h.Cmd
-  - gauge: 114547,S,60m.Cmd
-    model_folder: ../models_torch/tarp_version_49/
+      - 114135,Precip,h.Cmd
+  - gauge: 114135,S,60m.Cmd
+    model_folder: ../models_torch/base_models_ensembles_willenscharen/version_37
     columns:
-      - 4466,SHum,vwsl
-      - 4466,SHum,bfwls
-      - 4466,AT,h.Cmd
-      - 114435,S,60m.Cmd
-      - 114050,S,60m.Cmd
-      - 114547,S,60m.Cmd
-      - 114547,Precip,h.Cmd
+      - 7427,SHum,vwsl
+      - 7427,SHum,bfwls
+      - 7427,AT,h.Mean
+      - 114200,S,60m.Cmd
+      - 114131,S,60m.Cmd
+      - 114135,Precip,h.Cmd
+      - 114135,S,60m.Cmd
     external_fcst:
-      - 114547,Precip,h.Cmd
-      #114547,Precip,h.Cmd : Tarp
-  - gauge: 114069,Precip,h.Cmd
-    model_folder: ../models_torch/version_74_treia/
-    columns:
-      - 4466,SHum,vwsl
-      - 4466,SHum,bfwls
-      - 4466,AT,h.Cmd
-      - 114435,S,60m.Cmd
-      - 114050,S,60m.Cmd
-      - 114547,S,60m.Cmd
-      - 114224,S,60m.Cmd
-      - 114061,S,60m.Cmd
-      - 114069,Precip,h.Cmd
-      - 114069,S,15m.Cmd
-    external_fcst:
-      - 114069,Precip,h.Cmd
-      #114069,Precip,h.Cmd : Treia
\ No newline at end of file
+      - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_88
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_106
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_110
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_116
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_137
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_171
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_175
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_181
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/base_models_ensembles_willenscharen/version_186
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+##############
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/willenscharen_version_37/
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 114135,S,60m.Cmd
+  #  model_folder: ../models_torch/willenscharen_version_94/
+  #  columns:
+  #    - 7427,SHum,vwsl
+  #    - 7427,SHum,bfwls
+  #    - 7427,AT,h.Mean
+  #    - 114200,S,60m.Cmd
+  #    - 114131,S,60m.Cmd
+  #    - 114135,Precip,h.Cmd
+  #    - 114135,S,60m.Cmd
+  #  external_fcst:
+  #    - 114135,Precip,h.Cmd
+  #- gauge: 112211,S,5m.Cmd
+  #  model_folder: ../models_torch/hollingstedt_version_37
+  #  columns:
+  #    - 4466,SHum,vwsl 
+  #    - 4466,SHum,bfwls 
+  #    - 4466,AT,h.Cmd-2 
+  #    - 114069,S,15m.Cmd 
+  #    - 111111,S,5m.Cmd 
+  #    - 9520081,S_Tide,1m.Cmd 
+  #    - 9530010,S_Tide,1m.Cmd 
+  #    - 112211,Precip,h.Cmd 
+  #    - 112211,S,5m.Cmd
+  #  external_fcst:
+  #    - 9530010,S_Tide,1m.Cmd
+  #    - 112211,Precip,h.Cmd
+  #- gauge: 114547,S,60m.Cmd
+  #  model_folder: ../models_torch/tarp_version_49/
+  #  columns:
+  #    - 4466,SHum,vwsl
+  #    - 4466,SHum,bfwls
+  #    - 4466,AT,h.Cmd
+  #    - 114435,S,60m.Cmd
+  #    - 114050,S,60m.Cmd
+  #    - 114547,S,60m.Cmd
+  #    - 114547,Precip,h.Cmd
+  #  external_fcst:
+  #    - 114547,Precip,h.Cmd
+  #    #114547,Precip,h.Cmd : Tarp
+  #- gauge: 114069,Precip,h.Cmd
+  #  model_folder: ../models_torch/version_74_treia/
+  #  columns:
+  #    - 4466,SHum,vwsl
+  #    - 4466,SHum,bfwls
+  #    - 4466,AT,h.Cmd
+  #    - 114435,S,60m.Cmd
+  #    - 114050,S,60m.Cmd
+  #    - 114547,S,60m.Cmd
+  #    - 114224,S,60m.Cmd
+  #    - 114061,S,60m.Cmd
+  #    - 114069,Precip,h.Cmd
+  #    - 114069,S,15m.Cmd
+  #  external_fcst:
+  #    - 114069,Precip,h.Cmd
+  #    #114069,Precip,h.Cmd : Treia
+ensemble_configs:      
+      #save_both -> calculate the base models and the ensemble model and save both
+      #calc_both -> calculate the base models and the ensemble model but only save the ensemble model
+      #only_ensemble -> load the results of the base models from the database
+      #naiv -> not implemented, should just ignore the weighting part of the ensemble_model. 
+      # These do not actually need to be full paths. The folder name is sufficient.
+    - model_names:
+      #- /home/mspils/Waterlevel/models_torch/willenscharen_version_94
+      #- /home/mspils/Waterlevel/models_torch/willenscharen_version_37
+        - base_version_37
+        - base_version_88
+        - base_version_106
+        - base_version_110
+        - base_version_116
+        - base_version_137
+        - base_version_171
+        - base_version_175
+        - base_version_181
+        - base_version_186
+      ensemble_name: calc_both_ensemble_willenscharen_0
+      ensemble_model : /home/mspils/Waterlevel/models_torch/ensemble_willenscharen_version_0
+      mode : calc_both
+    #- model_names:
+    #    - base_version_37
+    #    - base_version_88
+    #    - base_version_106
+    #    - base_version_110
+    #    - base_version_116
+    #    - base_version_137
+    #    - base_version_171
+    #    - base_version_175
+    #    - base_version_181
+    #    - base_version_186
+    #  ensemble_name: save_both_willenscharen_0
+    #  ensemble_model : /home/mspils/Waterlevel/models_torch/ensemble_willenscharen_version_0
+    #  mode : save_both
+    #- model_names:
+    #    - base_version_37
+    #    - base_version_88
+    #    - base_version_106
+    #    - base_version_110
+    #    - base_version_116
+    #    - base_version_137
+    #    - base_version_171
+    #    - base_version_175
+    #    - base_version_181
+    #    - base_version_186
+    #  ensemble_name: only_ensemble_willenscharen_0
+    #  ensemble_model : /home/mspils/Waterlevel/models_torch/ensemble_willenscharen_version_0
+    #  mode : only_ensemble
+    #- model_names:
+    #    - base_version_37
+    #    - base_version_88
+    #    - base_version_106
+    #    - base_version_110
+    #    - base_version_116
+    #    - base_version_137
+    #    - base_version_171
+    #    - base_version_175
+    #    - base_version_181
+    #    - base_version_186
+    #  ensemble_name: naiv_willenscharen_0
+    #  mode : naiv
diff --git a/configs/db_train.yaml b/configs/db_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9716e86896da83d65248a43e3c12df9240b37cdf
--- /dev/null
+++ b/configs/db_train.yaml
@@ -0,0 +1,118 @@
+
+general:
+  n_trials: 1
+  study_name: "my_study" #Currently only in memory
+  log_dir : ../models_torch/db_testing/
+  #save to database?
+  save_all_to_db: False
+  save_best_to_db: True 
+  model_name_prefix: "db_test_willenscharen"
+
+# Database configuration
+db_params:
+  user: c##mspils
+  #Password is for a local database with no internet connection; it is safe to keep it in the repository.
+  password: cobalt_deviancy
+  dsn: localhost/XE
+  #optional parameter: config_dir
+
+
+data:
+  batch_size: 512
+  num_workers: 7
+  sensors:
+    - 7427,SHum,vwsl
+    - 7427,SHum,bfwls
+    - 7427,AT,h.Mean
+    - 114200,S,60m.Cmd
+    - 114131,S,60m.Cmd
+    - 114135,Precip,h.Cmd
+    - 114135,S,60m.Cmd
+  external_fcst:
+    - 114135,Precip,h.Cmd
+  target_column: 114135,S,60m.Cmd
+  #if percentile is above 1 it will be treated as a value instead of using the quantile
+  percentile : 0.95
+  # Time ranges (optional, will use defaults 8 years, 1 year and 1 year if not provided)
+  time_ranges:
+    train:
+      start: 2014-01-01 03:00:00
+      end: 2021-09-25 02:00:00
+    val:
+      start: 2021-09-25 02:00:00
+      end: 2023-05-23 05:00:00
+    test:
+      start: 2023-05-23 05:00:00
+      end: 2025-01-17 08:00:00
+    #train:
+    #  start: 2024-08-01
+    #  end: 2024-08-30
+    #val:
+    #  start: 2024-09-01
+    #  end: 2024-09-30
+    #test:
+    #  start: 2024-10-01
+    #  end: 2024-10-30
+
+# Training configuration
+training:
+  max_epochs: 10
+  log_every_n_steps: 10
+  monitor: "hp/val_loss"
+  
+
+
+# Hyperparameter search spaces
+hyperparameters:
+  differencing:
+    type: int
+    min: 1
+    max: 1
+  
+  learning_rate:
+    type: float
+    #min: 0.00001
+    #max: 0.01
+    min: 0.0006861717141475389
+    max: 0.0006861717141475389
+  optimizer:
+    type: categorical
+    values: ["adam"]
+    
+  model_architecture:
+    type: categorical
+    values: ["classic_lstm"]
+    
+  # Model specific parameters
+  classic_lstm:
+    hidden_size_lstm:
+      type: int
+      #min: 32
+      #max: 512
+      min: 131
+      max: 131
+
+    num_layers_lstm:
+      type: int
+      min: 1
+      max: 1
+      #2
+    hidden_size:
+      type: int
+      #min: 32
+      #max: 512
+      min: 179
+      max: 179
+    num_layers:
+      type: int
+      min: 3
+      max: 3
+      #min: 2
+      #max: 4
+    dropout:
+      type: float
+      #min: 0
+      #max: 0.5
+      min: 0.1830338279898485
+      max: 0.1830338279898485
+
diff --git a/src/app.py b/src/app.py
index 80a7f04dd921f51fcc192b10b6746171d5115071..002888ae3e459a92d3030484a5962b9eb2d65b89 100644
--- a/src/app.py
+++ b/src/app.py
@@ -23,7 +23,7 @@ from utils.db_tools.orm_classes import (InputForecasts, Log, Metric, Modell,
                                         ModellSensor, PegelForecasts,
                                         SensorData)
 
-
+debug = False
 
 #TODO add logging to dashboard?
 class ForecastMonitor:
@@ -86,7 +86,8 @@ class ForecastMonitor:
                                 width="auto"),
                             dbc.Col(dbc.Tabs([
                                     dbc.Tab(label="Modell Monitoring",tab_id="model-monitoring",**TAB_STYLE),
-                                    dbc.Tab(label="System Logs",tab_id="system-logs",**TAB_STYLE)
+                                    dbc.Tab(label="System Logs",tab_id="system-logs",**TAB_STYLE),
+                                    dbc.Tab(label="Modell  Metrics",tab_id="model-metrics",**TAB_STYLE)
                                     ],
                                     id="main-tabs",
                                     active_tab="model-monitoring",
@@ -176,39 +177,46 @@ class ForecastMonitor:
             ])
         ], fluid=True, className="py-4")
 
-        # Add a container to hold the tab content
-        tab_content = html.Div(
-            id="tab-content",
-            children=[
-                dbc.Container(
-                    id="page-content",
-                    children=[],
-                    fluid=True,
-                    className="py-4"
-                )
-            ]
-        )
-        # Main content with tabs
-        #main_content = dbc.Tabs([
-        #    dbc.Tab(model_monitoring_tab, label="Modell Monitoring", tab_id="model-monitoring"),
-        #    dbc.Tab(system_logs_tab, label="System Logs", tab_id="system-logs")
-        #], id="main-tabs", active_tab="model-monitoring")
+
+        # Model Metrics Tab Content
+        self.model_metrics_tab = dbc.Container([
+            dbc.Card([
+                dbc.CardBody([
+                    html.H3("Modell Metriken", className="mb-4"),
+                    dash_table.DataTable(
+                        id='metric-table',
+                        columns=[
+                            {'name': 'Zeitstempel', 'id': 'timestamp'},
+                            {'name': 'Level', 'id': 'level'},
+                            {'name': 'Pegel', 'id': 'sensor'},
+                            {'name': 'Nachricht', 'id': 'message'},
+                            {'name': 'Modul', 'id': 'module'},
+                            {'name': 'Funktion', 'id': 'function'},
+                            {'name': 'Zeile', 'id': 'line'},
+                            {'name': 'Exception', 'id': 'exception'}
+                        ],
+                        **TABLE_STYLE,
+                        page_size=20,
+                        page_action='native',
+                        sort_action='native',
+                        sort_mode='multi',
+                        filter_action='native',
+                    )
+                ])
+            ])
+        ], fluid=True, className="py-4")
 
         # Combine all elements
         self.app.layout = html.Div([
             header,
-            html.Div(id="tab-content", children=[
-                self.model_monitoring_tab,
-                self.system_logs_tab
-            ], style={"display": "none"}),  # Hidden container for both tabs
+            #html.Div(id="tab-content", children=[
+            #    self.model_monitoring_tab,
+            #    self.system_logs_tab
+            #], style={"display": "none"}),  # Hidden container for both tabs
             html.Div(id="visible-tab-content")  # Container for the visible tab
         ], className="min-vh-100 bg-light")
 
-        # Combine all elements
-        #self.app.layout = html.Div([
-        #    header,
-        #    main_content
-        #], className="min-vh-100 bg-light")
+
         
     def get_model_name_from_path(self, modelldatei):
         if not modelldatei:
@@ -361,7 +369,8 @@ class ForecastMonitor:
                 
                 # Get last 144 hours of sensor data
                 time_threshold = datetime.now(timezone.utc) - timedelta(hours=144) - BUFFER_TIME
-                time_threshold= pd.to_datetime("2024-09-13 14:00:00.000") - timedelta(hours=144) #TODO rausnehmen
+                if debug:
+                    time_threshold= pd.to_datetime("2024-09-13 14:00:00.000") - timedelta(hours=144)
 
                 stmt = select(SensorData).where(
                     SensorData.tstamp >= time_threshold,
@@ -400,8 +409,25 @@ class ForecastMonitor:
             print(f"Error getting model metrics: {str(e)}")
             return None, None
 
+    def get_confidence(self, model_id,flood=False,subset='test'):
+        """Get confidence intervals for a specific model. Assumes each model has only one set of metrics with the same name."""
+ 
+        try:
+            if flood:
+                metric_names = [f"{subset}_conf05_flood",f"{subset}_conf10_flood",f"{subset}_conf50_flood"]
+            else:
+                metric_names = [f"{subset}_conf05",f"{subset}_conf10",f"{subset}_conf50"]
 
+            stmt = select(Metric).where(Metric.model_id == model_id,Metric.metric_name.in_(metric_names))
+            df = pd.read_sql(sql=stmt, con=self.engine)
+            df = df.pivot(index="timestep",columns="metric_name",values="value")
             
+            return df
+        except Exception as e:
+            print(f"Error getting confidence interval: {str(e)}")
+            return None
+
+
 
     def setup_callbacks(self):
         # Add this new callback at the start of setup_callbacks
@@ -494,27 +520,28 @@ class ForecastMonitor:
             model_name = selected_row['model_name']
             actual_model_name = selected_row['actual_model_name']
             
-            # Get logs
+            #First collect all the data needed
             logs = self.get_recent_logs(sensor_name)
+            df_historical = self.get_historical_data(model_id)
+            sensor_names = list(df_historical.columns)
+            df_inp_fcst,ext_forecast_names = self.get_input_forecasts(sensor_names)
+            df_forecasts = self.get_recent_forecasts(actual_model_name)
+            df_conf = self.get_confidence(model_id,flood=False,subset='test')
+
+
+            # Format logs
             log_table = create_log_table(logs)
             log_view = html.Div([
                 html.H4(f"Neueste Logs für {model_name}"),
                 log_table
             ])
 
-            # Get historical data
-            df_historical = self.get_historical_data(model_id)
-            df_filtered = df_historical[df_historical.isna().any(axis=1)]
-
-            sensor_names = list(df_historical.columns)
-            
+            # Format historical data
             if not df_historical.empty:
                 fig = create_historical_plot(df_historical, model_name)
+                df_filtered = df_historical[df_historical.isna().any(axis=1)]
                 historical_table = create_historical_table(df_filtered)
-
-
                 historical_view = html.Div([
-                    #html.H4(f"Messdaten {model_name}"),
                     dcc.Graph(figure=fig),
                     html.H4("Zeitpunkte mit fehlenden Messdaten", style={'marginTop': '20px', 'marginBottom': '10px'}),
                     html.Div(historical_table, style={'width': '100%', 'padding': '10px'})
@@ -523,12 +550,10 @@ class ForecastMonitor:
                 historical_view = html.Div("Keine Messdaten verfügbar")
 
 
-            # Get forecast data
-            df_inp_fcst,ext_forecast_names = self.get_input_forecasts(sensor_names)
-            inp_fcst_table = create_inp_forecast_status_table(df_inp_fcst,ext_forecast_names)
-            df_forecasts = self.get_recent_forecasts(actual_model_name)
+            # Format forecast data
             if not df_forecasts.empty:
-                fig_fcst = create_model_forecasts_plot(df_forecasts,df_historical,df_inp_fcst)
+                inp_fcst_table = create_inp_forecast_status_table(df_inp_fcst,ext_forecast_names)
+                fig_fcst = create_model_forecasts_plot(df_forecasts,df_historical,df_inp_fcst,df_conf)
                 fcst_view = html.Div([
                     html.H4(f"Pegelvorhersage Modell {model_name}"),
                     dcc.Graph(figure=fig_fcst),
@@ -548,7 +573,7 @@ class ForecastMonitor:
                             id={'type': 'metric-selector', 'section': 'metrics'},
                             options=[
                                 {'label': metric.upper(), 'value': metric}
-                                for metric in ['mae', 'mse', 'nse', 'kge', 'p10', 'p20', 'r2', 'rmse', 'wape']
+                                for metric in ['mae', 'mse', 'nse', 'kge', 'p10', 'p20', 'r2', 'rmse', 'wape','conf05','conf10','conf50']
                             ],
                             value='mae',
                             clearable=False,
@@ -692,7 +717,7 @@ class ForecastMonitor:
                     return True, [html.Span(title, style={'flex': '1'}), html.Span("▼", className="ms-2")]
             return is_open, current_children
     
-    def run(self, host='0.0.0.0', port=8050, debug=True):
+    def run(self, host='0.0.0.0', port=8050):
         self.app.run(host=host, port=port, debug=debug)
 
 if __name__ == '__main__':
diff --git a/src/dashboard/tables_and_plots.py b/src/dashboard/tables_and_plots.py
index ee3925882d279ef84a374060fb24942f2cdcff6a..2c3d9c270c066cd85ce7bccf409ef7c72e2d7b01 100644
--- a/src/dashboard/tables_and_plots.py
+++ b/src/dashboard/tables_and_plots.py
@@ -17,6 +17,9 @@ from dashboard.styles import (FORECAST_STATUS_TABLE_STYLE,
                               get_forecast_status_conditional_styles)
 from dashboard.constants import PERIODS_EXT_FORECAST_TABLE
 
+def add_opacity(c,opacity):
+    return f"rgba{c[3:-1]},{opacity})"
+
 def create_status_summary_chart(status):
     """Creates a simple pie chart showing the percentage of models with current forecasts."""
     total_models = len(status['model_status'])
@@ -44,9 +47,9 @@ def create_status_summary_chart(status):
 
 
 
-def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst):
+def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst,df_conf):
     """Create a plot showing recent model forecasts alongside historical data"""
-    sensor_name = df_forecasts["sensor_name"][0]
+    sensor_name = df_forecasts["sensor_name"].values[0]
     start_times = sorted(set(df_forecasts["tstamp"]) | set(df_inp_fcst["tstamp"]))
     start_time_colors = {t: COLOR_SET[i % len(COLOR_SET)] for i, t in enumerate(start_times)}
 
@@ -79,11 +82,14 @@ def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst):
                 show_legend_check.add(start_time)
 
             legend_group = f'{start_time.strftime("%Y-%m-%d %H:%M")}'
-
+            
+            x = [start_time + timedelta(hours=i) for i in range(1, 49)]
+            y = [row[f'h{i}'] for i in range(1, 49)]
+            most_recent = start_time == df_forecasts["tstamp"].max()
             fig.add_trace(
                 go.Scatter(
-                    x=[start_time + timedelta(hours=i) for i in range(1, 49)],
-                    y=[row[f'h{i}'] for i in range(1, 49)],
+                    x=x,
+                    y=y,
                     name=legend_group,
                     legendgroup=legend_group,
                     showlegend=show_legend,
@@ -92,14 +98,47 @@ def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst):
                         width=3 if member == 0 else 1,
                     ),
                     hovertemplate=(
-                        'Time: %{x}<br>'
-                        'Value: %{y:.2f}<br>'
-                        f'Member: {member}<br>'
-                        f'Start: {start_time}<extra></extra>'
+                        f'Member: {member:2} - '
+                        'Value: %{y:.2f}'
                     ),
-                    visible= True if start_time == df_forecasts["tstamp"].max() else 'legendonly'
+                    visible= True if most_recent else 'legendonly'
                 ),
             )
+            if most_recent and member == 0 and not df_conf.empty:
+                for alpha,conf_name,opacity in zip(['05','10','50'],['95%','90%','50%'],[0.1,0.2,0.3]):
+                    #TODO: the 'test' subset is hardcoded (suboptimal), and the flood metrics are not used
+                    subset = 'test'
+                    metric_name = f"{subset}_conf{alpha}"#if not flood else f"{subset}_conf{alpha}_flood"
+
+                    fillcolor = add_opacity(start_time_colors[start_time],opacity)
+
+                    fig.add_trace(
+                        go.Scatter(
+                            x=x,
+                            y=y+df_conf[metric_name],#.values,
+                            name=f"{conf_name} CI Hauptlauf",
+                            #legendgroup=legend_group,
+                            showlegend=False,
+                            mode='lines',
+                            line=dict(width=0),
+                            fillcolor=fillcolor,
+                            visible= True if most_recent else 'legendonly')
+                    )
+                    fig.add_trace(
+                        go.Scatter(
+                            x=x,
+                            y=y-df_conf[metric_name],#.values,
+                            name=f"{conf_name} CI Hauptlauf",
+                            #legendgroup=legend_group,
+                            #showlegend=True,
+                            mode='lines',
+                            line=dict(width=0),
+                            fill='tonexty',
+                            fillcolor=fillcolor,
+                            visible= True if most_recent else 'legendonly')
+                    )
+
+                
     temp_layout_dict = {"yaxis": {"title" : {"text":"Pegel [cm]"}}}
 
     fig.update_layout(**temp_layout_dict)
@@ -131,12 +170,10 @@ def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst):
             # Get unique forecast start times for color assignment        
             # For each member
             for member in members:
-                member_data = sensor_data[sensor_data['member'] == member]
-                
+                member_data = sensor_data[sensor_data['member'] == member]                
                 # For each forecast (row)
                 for _, row in member_data.iterrows():
                     start_time = row['tstamp']
-                    #legend_group = f'{sensor} {start_time.strftime("%Y%m%d%H%M")}'
                     legend_group = f'{start_time.strftime("%Y-%m-%d %H:%M")} {sensor}'
                     if start_time in show_legend_check:
                         show_legend = False
@@ -160,10 +197,8 @@ def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst):
                                 dash='solid' if member == 0 else 'dot'
                             ),
                             hovertemplate=(
-                                'Time: %{x}<br>'
+                                f'Member: {member:2} - '
                                 'Value: %{y:.2f}<br>'
-                                f'Member: {member}<br>'
-                                f'Start: {start_time}<extra></extra>'
                             ),
                             visible= True if start_time == df_forecasts["tstamp"].max() else 'legendonly',
                             yaxis=f"y{sensor_idx+1}"
@@ -175,13 +210,12 @@ def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst):
         #title='Model Forecasts',
         #xaxis_title='Time',
         #yaxis_title='',
-        #hovermode='x unified',
+        hovermode='x',#'x unified', #TODO Change when this is done: https://github.com/plotly/plotly.js/issues/4294
         legend=dict(
             yanchor="top",
             y=0.99,
             xanchor="left",
             x=1.05,
-            #bgcolor='rgba(255, 255, 255, 0.8)'
         )
     )
     return fig
diff --git a/src/data_tools/data_module.py b/src/data_tools/data_module.py
index a062a767f34da819195b6822d4c5521171e44cba..29ca9d3345a467d5ad7db362b1935709ae6b241f 100644
--- a/src/data_tools/data_module.py
+++ b/src/data_tools/data_module.py
@@ -6,12 +6,15 @@ from typing import Tuple, Union
 
 import lightning.pytorch as pl
 import pandas as pd
+from sqlalchemy import between, bindparam, select
 import torch
 from sklearn.preprocessing import StandardScaler
 from torch.utils.data import DataLoader
+from datetime import datetime, timedelta
 
 from data_tools.datasets import TimeSeriesDataSet, TimeSeriesDataSetExog # pylint: disable=import-error
 import utils.utility as ut # pylint: disable=import-error
+from utils.db_tools.orm_classes import SensorData #TODO consider moving DBWaVoDataModule to the db_tools folder
 
 
 class WaVoDataModule(pl.LightningDataModule):
@@ -104,7 +107,7 @@ class WaVoDataModule(pl.LightningDataModule):
 
         self.df_scaled = pd.DataFrame(self.scaler.transform(self.df),index=self.df.index,columns=self.df.columns)
 
-    def _load_df(self):
+    def _load_df(self) -> pd.DataFrame:
         df = pd.read_csv(self.filename, index_col=0, parse_dates=True)
         if not isinstance(df.index,pd.core.indexes.datetimes.DatetimeIndex):
             df = pd.read_csv(self.filename,index_col=0,parse_dates=True,date_format='%d.%m.%Y %H:00')
@@ -309,3 +312,166 @@ class WaVoDataModule(pl.LightningDataModule):
             #y = torch.Tensor(self.df_y[start:end].values[:,0])
 
             return TimeSeriesDataSet(x,y,self.in_size,self.out_size)
+
+
+class DBWaVoDataModule(pl.LightningDataModule):
+    def __init__(self, config, engine,differencing=1):
+        super().__init__()
+        self.config = config
+        self.engine = engine
+        self.save_hyperparameters(ignore=['engine', 'config'])
+        
+        # Get training parameters from config
+        self.in_size = 144
+        self.out_size = 48
+        #self.embed_time = False
+        self.percentile = config["percentile"]
+        self.batch_size = config['batch_size']
+        self.num_workers = config['num_workers']
+        self.differencing = differencing
+        self.sensors = config["sensors"]
+        self.external_fcst = config["external_fcst"]
+        self.gauge_column =  config["target_column"]
+        self.gauge_idx = self.sensors.index(self.gauge_column)
+        if self.differencing:
+            self.target_column = "d1"
+            self.target_idx = len(self.sensors)
+        else:
+            self.target_column = self.gauge_column
+            self.target_idx = self.gauge_idx
+        
+        # Set time ranges from config or use defaults
+        self.time_ranges = self._setup_time_ranges(config)
+        self.df_train = self._load_df(self.time_ranges['train']['start'], self.time_ranges['train']['end'])
+        self.df_val = self._load_df(self.time_ranges['val']['start'], self.time_ranges['val']['end'])
+        self.df_test = self._load_df(self.time_ranges['test']['start'], self.time_ranges['test']['end'])
+        self._set_threshold()
+        self.scaler = StandardScaler()
+        self.scaler.fit(self.df_train)
+        #TODO test hyperparameter saving
+        self.feature_count = self.df_train.shape[1]
+        self.save_hyperparameters(ignore=['scaler','time_ranges','config','engine'])
+        self.save_hyperparameters({
+            "scaler": pickle.dumps(self.scaler),
+            #"feature_count": self.feature_count,
+            "in_size": self.in_size,
+            "out_size": self.out_size,
+            "batch_size": self.batch_size,
+            "embed_time": False,
+            #"target_column": self.target_column,
+            #"target_idx": self.target_idx,
+            #"gauge_column": self.gauge_column,
+            #"gauge_idx": self.gauge_idx,
+            #"threshold": self.threshold,
+            #These aren't really hyperparameters, but maybe useful to have them logged
+            "train_start":str(self.time_ranges['train']['start']),
+            "train_end":str(self.time_ranges['train']['end']),
+            "val_start":str(self.time_ranges['val']['start']),
+            "val_end":str(self.time_ranges['val']['end']),
+            "test_start":str(self.time_ranges['test']['start']),
+            "test_end":str(self.time_ranges['test']['end']),
+            #"embed_time":False,
+            })
+        
+    def _set_threshold(self):
+        # behaves slightly different than the original code due to calculation the threshold only on the training data
+        if 0 < self.percentile < 1:
+            self.threshold = self.df_train[self.gauge_column].quantile(
+                self.percentile).item()
+        else:
+            self.threshold = self.percentile
+
+    def _load_df(self, start_time, end_time):
+        stmt = select(SensorData).where(
+            between(SensorData.tstamp, bindparam("start_time"), bindparam("end_time")),
+            SensorData.sensor_name.in_(bindparam("columns", expanding=True)),
+        )
+        df = pd.read_sql(
+            sql=stmt,
+            con=self.engine,
+            index_col="tstamp",
+            params={
+                "start_time": start_time,
+                "end_time": end_time + pd.Timedelta(hours=self.out_size),
+                "columns": self.sensors,
+            })
+        df = df.pivot(columns="sensor_name", values="sensor_value")[self.sensors]
+        for col in self.config["external_fcst"]:
+            df[col] = df[col].shift(-self.out_size)
+        df = df[:-self.out_size]
+        df = ut.fill_missing_values(df)
+        
+        if self.differencing == 1:
+            df[self.target_column] = df[self.gauge_column].diff()
+            df.iloc[0, df.columns.get_loc(self.target_column)] = 0
+        return df
+
+    def _setup_time_ranges(self, config):
+        now = datetime.now()
+        default_ranges = {
+            'train': {
+                'start': now - timedelta(days=365*10),
+                'end': now - timedelta(days=365*2)
+            },
+            'val': {
+                'start': now - timedelta(days=365*2),
+                'end': now - timedelta(days=365)
+            },
+            'test': {
+                'start': now - timedelta(days=365),
+                'end': now
+            }
+        }
+        
+        if 'time_ranges' in config:
+            time_ranges = config['time_ranges']
+            for split in ['train', 'val', 'test']:
+                if split in time_ranges:
+                    for bound in ['start', 'end']:
+                        if bound in time_ranges[split]:
+                            default_ranges[split][bound] = pd.to_datetime(time_ranges[split][bound])
+                            
+        return default_ranges
+        
+    def setup(self, stage=None):
+        x = torch.tensor(self.scaler.transform(self.df_train),dtype=torch.float32)
+        y = x[:,self.target_idx]
+        self.train_set = TimeSeriesDataSet(x,y,self.in_size,self.out_size)
+        x = torch.tensor(self.scaler.transform(self.df_val),dtype=torch.float32)
+        y = x[:,self.target_idx]
+        self.val_set = TimeSeriesDataSet(x,y,self.in_size,self.out_size)
+        x = torch.tensor(self.scaler.transform(self.df_test),dtype=torch.float32)
+        y = x[:,self.target_idx]
+        self.test_set = TimeSeriesDataSet(x,y,self.in_size,self.out_size)
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_set,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers
+        )
+    
+    def predict_dataloader(self):
+        return DataLoader(
+            self.train_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers
+        )
\ No newline at end of file
diff --git a/src/import_model.py b/src/import_model.py
index cf40609089f652fad3af821c8a707ad0aa81d205..e16a6e4e2767d8916eb66e6296daa34315f4b8d0 100644
--- a/src/import_model.py
+++ b/src/import_model.py
@@ -175,13 +175,14 @@ class ModelImporter:
                     end_time = end_time.round("h")
                 if "flood" in col_name:
                     tag = "flood"
+                    col_name = col_name.replace("_flood","")
                 else:
                     tag = "normal"
                 
                 metric = Metric(model_id=model_id,metric_name=col_name,timestep=i+1,start_time=start_time,end_time=end_time,value=row,tag=tag)
                 metrics.append(metric)
 
         session.add_all(metrics)
                 
     def _get_next_model_id(self, session: Session) -> float:
         """Get next available model ID"""
@@ -300,7 +301,7 @@ class ModelImporter:
             logging.info(msg=f"Updating {old_model} in_size to {model_hparams['in_size']}")
             old_model.in_size=model_hparams["in_size"]
         if old_model.target_sensor_name !=  model_config["gauge"]:
-            logging.info(msg=f"Updating {old_model} target_sensor_name to { model_config["gauge"]}")
+            logging.info(msg=f"Updating {old_model} target_sensor_name to {model_config['gauge']}")
             old_model.target_sensor_name= model_config["gauge"]
 
 
diff --git a/src/models/ensemble_models.py b/src/models/ensemble_models.py
index d8076fb86f6889fd7b2883c2d5f22b1fc6a2531c..f83c6845b7b80c8044732a8e017ae90213d85122 100644
--- a/src/models/ensemble_models.py
+++ b/src/models/ensemble_models.py
@@ -128,19 +128,18 @@ class WaVoLightningEnsemble(pl.LightningModule):
     """ Since the data for this is normalized using the scaler from the first model in model_list all models should be trained on the same or at least similar data i think?
     """
     # pylint: disable-next=unused-argument
-    def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate):
+    def __init__(self, model_list,model_path_list,hidden_size,num_layers,dropout,norm_func,learning_rate,return_base=False):
         super().__init__()
 
         self.model_list = model_list
         self.max_in_size = max([model.hparams['in_size'] for model in self.model_list])
-        self.in_size = self.max_in_size
         self.out_size = model_list[0].hparams['out_size']
         self.feature_count = model_list[0].hparams['feature_count']
-
-        assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx"
         self.target_idx = model_list[0].target_idx
+        self.scaler = model_list[0].scaler  #TODO load the scaler properly from yaml -> pass it as a parameter, as for the normal model.
+        assert len(set([model.target_idx for model in self.model_list])) == 1, "All models in the ensemble must have the same target_idx"
         self.model_architecture = 'ensemble'
-        self.scaler = model_list[0].scaler
+        self.return_base = return_base
 
         #self.weighting = WeightingModelAttention(
         self.weighting = WeightingModel(
@@ -155,11 +154,14 @@ class WaVoLightningEnsemble(pl.LightningModule):
 
         #TODO different insizes
         #TODO ModuleList?!
-        self.save_hyperparameters(ignore=['model_list','scaler'])
+        self.save_hyperparameters(ignore=['model_list','scaler','return_base'])
         self.save_hyperparameters({"scaler": pickle.dumps(self.scaler)})
         for model in self.model_list:
             model.freeze()
 
+    def set_return_base(self,return_base):
+        self.return_base = return_base
+
     def _common_step(self, batch, batch_idx,dataloader_idx=0):
         """
         Computes the weighted average of the predictions of the models in the ensemble.
@@ -185,7 +187,8 @@ class WaVoLightningEnsemble(pl.LightningModule):
         y_hats = torch.stack(y_hat_list,axis=1)
         w = self.weighting(x)
         y_hat = torch.sum((y_hats*w),axis=1)
-
+        if self.return_base:
+            return y_hat, y_hats
         return y_hat
 
 
diff --git a/src/models/lightning_module.py b/src/models/lightning_module.py
index bbecdafca1eb1f3502406eea6c3bd42187e2b16c..61e379112289c1f3cb960d149ed85235fa3729cd 100644
--- a/src/models/lightning_module.py
+++ b/src/models/lightning_module.py
@@ -69,7 +69,7 @@ class WaVoLightningModule(pl.LightningModule):
 
         self.save_hyperparameters(ignore=['scaler'])
         #self.save_hyperparameters({"scaler": pickle.dumps(self.scaler),})
-        
+
         if 'timestamp' in self.hparams: #TODO test if this works
             #print("already has timestamp")
             self.timestamp = self.hparams.timestamp
@@ -126,7 +126,6 @@ class WaVoLightningModule(pl.LightningModule):
             x = batch[0]
             #This is not great, but necessary if i want to use a scaler.
             pred = self.model(x)
-            #self.scaler.inverse_transform(x.reshape(-1, self.feature_count).cpu()).reshape(-1,self.in_size,self.feature_count)
             pred = ut.inv_standard(pred, self.scaler.mean_[self.target_idx], self.scaler.scale_[self.target_idx])
 
         elif self.model_architecture == "dlinear":
@@ -149,7 +148,7 @@ class WaVoLightningModule(pl.LightningModule):
         if self.differencing == 1:
             #This is only tested with the lstms
             x_org = self.scaler.inverse_transform(x.reshape(-1, self.feature_count).cpu()).reshape(-1,self.in_size,self.feature_count).astype('float32')
-                
+
             x_base = x_org[:,-1,self.gauge_idx].round(2) #round to get rid of floating point errors, the measure value is always an integer or has at most 1 decimal
             x_base = torch.from_numpy(x_base).unsqueeze(1)
             x_base = x_base.to(pred.device)
@@ -192,6 +191,9 @@ class WaVoLightningModule(pl.LightningModule):
             optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
         elif self.hparams.optimizer =='adamw':
             optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            raise ValueError(f"Unknown optimizer {self.hparams.optimizer}")
+            # no fallthrough: unsupported optimizers fail fast via the ValueError above
         return optimizer
 
     def get_network(self) -> nn.Module:
diff --git a/src/predict_database.py b/src/predict_database.py
index 5ef0d5b27fd49f6f0a940b98ca6b69da81a90983..34a6671d39fab14233aa7e9aced3e56cfcaa7550 100644
--- a/src/predict_database.py
+++ b/src/predict_database.py
@@ -11,8 +11,9 @@ import oracledb
 import pandas as pd
 import yaml
 
-from utils.db_tools.db_tools import OracleWaVoConnection
-from utils.db_tools.db_logging import OracleDBHandler
+import utils.db_tools.db_tools as dbt
+import utils.db_tools.db_logging as dbh
+
 
 
 def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]:
@@ -42,7 +43,7 @@ def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]:
         main_config["end"] = pd.to_datetime(main_config["end"])
 
     if main_config.get("fake_external"):
-        assert main_config["range"] == True, "fake_external only works with range=True"
+        assert main_config["range"], "fake_external only works with range=True"
 
     if passed_args.start is not None:
         main_config["start"] = passed_args.start
@@ -57,32 +58,13 @@ def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]:
         #assert all([isinstance(k,str) and isinstance(v,str) for k,v in gauge_config["external_fcst"].items()])
         assert all([isinstance(c,str) for c in gauge_config["external_fcst"]])
 
-    return main_config, gauge_configs
-
-
-def _prepare_logging_1() -> None:
-    logging.basicConfig(level=logging.INFO)
-    logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:])
+    ensemble_configs = configs.get("ensemble_configs", [])
+    for ensemble_config in ensemble_configs:
+        #ensemble_config["model_names"] = [Path(p) for p in ensemble_config['model_names']]
+        ensemble_config["ensemble_name"] = Path(ensemble_config['ensemble_name'])
+        ensemble_config["ensemble_model"] = ensemble_config.get('ensemble_model')
 
-
-def _prepare_logging_2(con) -> None:
-
-    old_factory = logging.getLogRecordFactory()
-    def record_factory(*args, **kwargs):
-        record = old_factory(*args, **kwargs)
-        record.gauge_id = kwargs.pop("gauge_id", None)
-        return record
-    logging.setLogRecordFactory(record_factory)
-
-    log_formatter = logging.Formatter(
-        "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S"
-    )
-    db_handler = OracleDBHandler(con)
-    db_handler.setFormatter(log_formatter)
-
-    db_handler.setLevel(logging.INFO)
-    logging.getLogger().addHandler(db_handler)
-    logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:])
+    return main_config, gauge_configs, ensemble_configs
 
 
 def parse_args() -> argparse.Namespace:
@@ -111,9 +93,9 @@ def parse_args() -> argparse.Namespace:
 
 def main(passed_args) -> None:
     """Main Function"""
-    main_config, gauge_configs = get_configs(passed_args=passed_args)
+    main_config, gauge_configs, ensemble_configs = get_configs(passed_args=passed_args)
 
-    _prepare_logging_1()
+    dbh.prepare_logging_1()
 
     if "config_dir" in main_config["db_params"]:
         logging.info(
@@ -125,7 +107,7 @@ def main(passed_args) -> None:
         logging.info("Initiating Thin mode")
 
     con = oracledb.connect(**main_config["db_params"])
-    _prepare_logging_2(con)
+    dbh.prepare_logging_2(con)
     # This is logged twice into Stout, because if thinks break it would not be logged at all.
     if "config_dir" in main_config["db_params"]:
         logging.info(
@@ -134,12 +116,15 @@ def main(passed_args) -> None:
         )
     else:
         logging.info("Initiating Thin mode")
-    connector = OracleWaVoConnection(con, main_config)
+    connector = dbt.OracleWaVoConnection(con, main_config)
     connector.maybe_update_tables()  # Potentially load new data
 
     for gauge_config in gauge_configs:
         connector.handle_gauge(gauge_config)
 
+    for ensemble_config in ensemble_configs:
+        connector.handle_ensemble(ensemble_config)
+
     logging.info("Finished %s with parameters %s ", sys.argv[0], sys.argv[1:])
 
 
diff --git a/src/train_db.py b/src/train_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b199fe0717c8ea1d43c41275a1d34d6a65311d1
--- /dev/null
+++ b/src/train_db.py
@@ -0,0 +1,274 @@
+"""
+Main Module to start training models using database data
+"""
+
+import argparse
+import logging
+from pathlib import Path
+import yaml
+
+import lightning.pytorch as pl
+import optuna
+import oracledb
+import pandas as pd
+import torch
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.loggers import TensorBoardLogger
+
+
+from data_tools.data_module import DBWaVoDataModule
+from models.lightning_module import WaVoLightningModule
+from utils.callbacks import OptunaPruningCallback, WaVoCallback
+import utils.db_tools.db_logging as dbh
+import utils.db_tools.db_tools as dbt
+import utils.helpers as hp
+from sqlalchemy import create_engine
+from tbparse import SummaryReader
+
+
+
+
+
+
+class DBObjective:
+    """Objective class for Optuna optimization using database data"""
+    
+    def __init__(self, config, engine,log_dir):
+        self.config = config
+        self.engine = engine
+        self.log_dir = log_dir
+
+    def _get_callbacks(self, trial):
+        monitor = self.config['training']['monitor']
+        checkpoint_callback = ModelCheckpoint(
+            save_top_k=1,
+            monitor=monitor,
+            save_weights_only=True
+        )
+        pruning_callback = OptunaPruningCallback(trial, monitor=monitor)
+        early_stop_callback = EarlyStopping(
+            monitor=monitor,
+            mode="min",
+            patience=3
+        )
+        my_callback = WaVoCallback()
+        
+        return my_callback, [
+            checkpoint_callback,
+            pruning_callback,
+            early_stop_callback,
+            my_callback,
+        ]
+
+    def _suggest_value(self, trial, param_name, param_config):
+        """Helper method to suggest values based on parameter configuration"""
+        if param_config['type'] == 'int':
+            return trial.suggest_int(param_name, param_config['min'], param_config['max'])
+        elif param_config['type'] == 'float':
+            return trial.suggest_float(param_name, param_config['min'], param_config['max'])
+        elif param_config['type'] == 'categorical':
+            return trial.suggest_categorical(param_name, param_config['values'])
+        else:
+            raise ValueError(f"Unknown parameter type: {param_config['type']}")
+
+    def _get_model_params(self, trial):
+        """Get model parameters based on YAML configuration"""
+        hp_config = self.config['hyperparameters']
+        
+        # Get basic parameters
+        differencing = self._suggest_value(trial, "differencing", hp_config['differencing'])
+        learning_rate = self._suggest_value(trial, "lr", hp_config['learning_rate'])
+        optimizer = self._suggest_value(trial, "optimizer", hp_config['optimizer'])
+        model_architecture = self._suggest_value(trial, "model_architecture", hp_config['model_architecture'])
+        
+        # Get architecture-specific parameters
+        model_params = {}
+        if model_architecture in hp_config:
+            for param_name, param_config in hp_config[model_architecture].items():
+                model_params[param_name] = self._suggest_value(trial, param_name, param_config)
+            
+        return model_architecture, differencing, learning_rate, optimizer, model_params
+
+    def __call__(self, trial):
+        model_architecture, differencing, learning_rate, optimizer, model_params = self._get_model_params(trial)
+        
+        data_module = DBWaVoDataModule(
+            self.config["data"],
+            self.engine,
+            differencing=differencing,
+
+        )
+        
+        # Prepare model
+        model = WaVoLightningModule(
+            model_architecture=model_architecture,
+            feature_count=data_module.feature_count,
+            in_size=data_module.hparams.in_size,
+            out_size=data_module.hparams.out_size,
+            scaler=data_module.scaler,
+            target_idx=data_module.target_idx,
+            gauge_idx=data_module.gauge_idx,
+            learning_rate=learning_rate,
+            optimizer=optimizer,
+            differencing=data_module.hparams.differencing,
+            embed_time=data_module.hparams.embed_time,
+            **model_params)
+        
+        my_callback, callbacks = self._get_callbacks(trial)
+        logger = TensorBoardLogger(self.log_dir, default_hp_metric=False)
+        
+        trainer = pl.Trainer(
+            default_root_dir=self.log_dir,
+            logger=logger,
+            accelerator="cuda" if torch.cuda.is_available() else "cpu",
+            devices=1,
+            callbacks=callbacks,
+            max_epochs=self.config['training']['max_epochs'],
+            log_every_n_steps=self.config['training']['log_every_n_steps']
+        )
+        
+        trainer.fit(model, data_module)
+        
+        # Save metrics to optuna
+        model_path = str(Path(trainer.log_dir).resolve())
+        trial.set_user_attr("model_path", model_path)
+        trial.set_user_attr("external_fcst", data_module.external_fcst)
+        
+        for metric in ["hp/val_nse", "hp/val_mae", "hp/val_mae_flood"]:
+            for i in [23, 47]:
+                trial.set_user_attr(f'{metric}_{i+1}', my_callback.metrics[metric][i].item())
+
+
+        if self.config["general"]["save_all_to_db"]:
+            model_id = dbt.add_model(self.engine,self.config,trainer.log_dir)
+            if model_id != -1:
+                dbt.add_model_sensor(self.engine,model_id,data_module.sensors,data_module.external_fcst)
+                dbt.add_metrics(self.engine,model_id,my_callback.metrics,data_module.time_ranges)
+            #dbt.add_metrics(self.engine,self.config,trainer.log_dir,key,value)
+
+
+                
+        return my_callback.metrics[self.config['training']['monitor']].item()
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Execute Hyperparameter optimization with database data using YAML configuration."
+    )
+    parser.add_argument(
+        "yaml_path",
+        type=Path,
+        help="Path to YAML configuration file"
+    )
+
+    return parser.parse_args()
+
+def get_configs(passed_args: argparse.Namespace) -> dict:
+    """
+    Retrieves the configurations from a YAML file.
+    Also converts the paths to Path objects and the dates to datetime objects.
+
+    Args:
+        passed_args (argparse.Namespace): The command-line arguments.
+
+    Returns:
+        Tuple[dict, List[dict]]: A tuple containing the main configuration dictionary and a list of gauge configuration dictionaries.
+    """
+    yaml_path = passed_args.yaml_path
+    with open(yaml_path, "r", encoding="utf-8") as file:
+        config = yaml.safe_load(file)
+    
+
+    config["general"]["log_dir"] = Path(config["general"]["log_dir"])
+    #TODO more typing stuff, assertions etc.
+
+    return config
+
+
+def main():
+    """Start hyperparameter optimization using database data."""
+    
+    args = parse_args()
+
+    config = get_configs(args)
+
+    config["general"]["log_dir"].mkdir(parents=True,exist_ok=True)
+
+    dbh.prepare_logging_1()
+    if "config_dir" in config["db_params"]:
+        logging.info("Initiating Thick mode with executable %s",config["db_params"]["config_dir"])
+        oracledb.init_oracle_client(lib_dir=config["db_params"]["config_dir"])
+    else:
+        logging.info("Initiating Thin mode")
+
+    con = oracledb.connect(**config["db_params"])
+    engine = create_engine("oracle+oracledb://", creator=lambda: con)
+
+    dbh.prepare_logging_2(con)
+
+
+    # Setup and run optimization
+    #study_name = f"{args.gauge}_{args.expname}"
+    #storage = optuna.storages.RDBStorage(url="oracle+oracledb://", engine_kwargs={'creator':lambda: con})
+    #storage = optuna.storages.RDBStorage(
+    #    url="sqlite:///:memory:", #TODO change to oracle
+    #    engine_kwargs={"pool_size": 20, "connect_args": {"timeout": 10}},
+    #)
+    storage = optuna.storages.RDBStorage(
+        url="sqlite:///:memory:",
+        #engine_kwargs={"pool_size": 20, "connect_args": {"timeout": 10}},
+    )
+    study_name = config["general"]["study_name"]
+    study = optuna.create_study(
+        study_name=study_name,
+        storage=storage,
+        direction="minimize",
+        load_if_exists=True
+    )
+    
+    study.set_metric_names(["hp/val_loss"])
+    
+    objective = DBObjective(
+        config=config,
+        engine=engine,
+        log_dir=config["general"]['log_dir']
+    )
+    
+    study.optimize(
+        objective,
+        n_trials=config["general"]["n_trials"],
+        gc_after_trial=True,
+        callbacks=[lambda study, trial: torch.cuda.empty_cache()]
+    )
+    logging.info("Finished hyperparameter optimization for study %s", study_name)
+    if config["general"]["save_best_to_db"]:
+        best_trial = study.best_trial
+        if config["general"]["save_all_to_db"]:
+            logging.info("Not saving best model %s to database, because all models are already saved.",best_trial.user_attrs["model_path"])
+
+            #TODO mark the model somehow maybe?
+        else:
+            model_path = best_trial.user_attrs["model_path"]
+            external_fcst = best_trial.user_attrs["external_fcst"]
+            hparams = hp.load_settings_model(Path(model_path))
+            time_ranges =  {s: {a : pd.to_datetime(hparams[f"{s}_{a}"]) for a in ["start","end"]} for s in ["train","val","test"]}
+
+            model_id = dbt.add_model(engine,config,model_path)
+            if model_id != -1:
+                dbt.add_model_sensor(engine,model_id,hparams["scaler"].feature_names_in_,external_fcst)
+
+                reader = SummaryReader(model_path,pivot=True)
+                df =reader.scalars
+                loss_cols = df.columns[df.columns.str.contains("loss")]
+                df = df.drop(["step","epoch"] + list(loss_cols),axis=1)[1:49]
+                metrics = df.to_dict("list")
+                for k,v in metrics.items():
+                    metrics[k] = torch.tensor(v)
+
+                dbt.add_metrics(engine,model_id,metrics,time_ranges)
+
+    con.close()
+
+if __name__ == "__main__":
+if __name__ == "__main__":
+    main()
diff --git a/src/utils/callbacks.py b/src/utils/callbacks.py
index 391fe2b0be8366e5b5a6b86a2c907c1a2cd8872a..c10525485d8038dfa0be7abd09a073d9757cec35 100644
--- a/src/utils/callbacks.py
+++ b/src/utils/callbacks.py
@@ -75,15 +75,20 @@ class WaVoCallback(Callback):
 
         #model = trainer.model
         #model.eval()
-        if isinstance(pl_module,WaVoLightningEnsemble):
-            checkpoint_path = next((Path(trainer.log_dir) / "checkpoints").iterdir())
-            model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=pl_module.model_list)    
+        if isinstance(pl_module,WaVoLightningEnsemble): #TODO testen
+            #checkpoint_path = next((Path(trainer.log_dir) / "checkpoints").iterdir())
+            #model = WaVoLightningEnsemble.load_from_checkpoint(checkpoint_path,model_list=pl_module.model_list)
+            model, _ = hp.load_ensemble_model(Path(trainer.log_dir), pl_module.model_dir_list,dummy=False)
+
         elif isinstance(pl_module,WaVoLightningModule):
             model = hp.load_model(Path(trainer.log_dir))#TODO consider using trainer.predict ckpt_path param
         else:
             raise ValueError("Model is not of type WaVoLightningModule or WaVoLightningEnsemble")
+        
 
         model.eval()
+        if torch.cuda.is_available():
+            model.cuda()
 
         datamodule = trainer.datamodule
 
@@ -133,28 +138,31 @@ class WaVoCallback(Callback):
 
     @torch.no_grad()
     def _get_pred(self,model,data_loader,ensemble=False,desc=''):
-        #TODO hier doll aufräumen :(
+        #TODO clean this up properly :( cuda probably causes problems when cuda is not installed.
+        use_cuda = torch.cuda.is_available()
         if ensemble:
             y_pred_list = [[] for _ in range(len(model.model_list))]
             for sub_model in model.model_list:
-                sub_model.cuda()
+                if use_cuda:
+                    sub_model.cuda()
                 
             for i, batch in enumerate(tqdm(data_loader,desc=f"ens {desc}",total=len(data_loader))):
-                batch = _repack_batch(batch)
+                batch = tuple(map(lambda x: x.cuda(),batch)) if use_cuda else batch
                 for j,sub_model in enumerate(model.model_list):
                     y_pred_list[j].append(sub_model.predict_step(batch,batch_idx=i))
 
             y_pred = torch.stack([torch.concat(y_pred) for y_pred in y_pred_list]).mean(dim=0)
         else:
-            model.cuda()
+            #if use_cuda:
+            #    model.cuda()
             y_pred_list = []
             for i, batch in enumerate(tqdm(data_loader,desc=desc,total=len(data_loader))):
-                batch = _repack_batch(batch)
+                batch = tuple(map(lambda x: x.cuda(),batch)) if use_cuda else batch
                 y_pred_list.append(model.predict_step(batch,batch_idx=i))
 
             y_pred = torch.concat(y_pred_list)
 
-        return y_pred.cuda()
+        return y_pred.to(model.device,torch.float32)
 
     def _get_loader_and_y_true(self,data_module:pl.LightningDataModule,model:pl.LightningModule,mode :str)-> tuple[DataLoader, torch.Tensor]:
         """Return the necessary data_loader and the ground truth. #TODO comments
@@ -182,7 +190,7 @@ class WaVoCallback(Callback):
         else:
             raise ValueError(f"Mode {mode} not recognized")
 
-        if ut.need_classic_input(data_module.model_architecture):
+        if ut.need_classic_input(model.model_architecture):
             y_true = torch.stack([y for _,y in dataset])
         else:
             y_true = torch.stack([y for _,y,_,_ in dataset])
@@ -194,25 +202,11 @@ class WaVoCallback(Callback):
             x_org = data_module.scaler.inverse_transform(x_true.reshape(-1,data_module.feature_count)).reshape(-1,data_module.in_size,data_module.feature_count)
             x_base = x_org[:,-1,data_module.gauge_idx].round(2)
             y_true = torch.from_numpy(x_base).unsqueeze(1)  + y_true.cumsum(1)
-            
-        return data_loader, y_true.cuda()
-
-
-def _repack_batch(batch):
-    if len(batch) == 2:
-        x, y = batch
-        x = x.cuda()
-        y = y.cuda()
-        batch = (x,y)
-    else:
-        x, y, x_time, y_time = batch
-        x = x.cuda()
-        y = y.cuda()
-        x_time = x_time.cuda()
-        y_time = y_time.cuda()
-        batch = (x,y,x_time,y_time)
-
-    return batch
+        
+        return data_loader, y_true.to(model.device,torch.float32)
+
+
+
     
 class OptunaPruningCallback(PyTorchLightningPruningCallback, Callback):
     """Custom optuna Pruning Callback, because CUDA/Lightning do not play well with the default one.
diff --git a/src/utils/db_tools/db_logging.py b/src/utils/db_tools/db_logging.py
index 66e15c7e9a50bde1e4809d4efa109400bae23c41..fdd07fcf38cd862e2bdb9a6a3b9ad17e7607412c 100644
--- a/src/utils/db_tools/db_logging.py
+++ b/src/utils/db_tools/db_logging.py
@@ -1,4 +1,5 @@
 import logging
+import sys
 from utils.db_tools.orm_classes import Log
 from sqlalchemy import create_engine
 from sqlalchemy.orm import Session
@@ -26,7 +27,7 @@ class OracleDBHandler(logging.Handler):
             query = """CREATE TABLE LOG(
                         ID int GENERATED BY DEFAULT AS IDENTITY (START WITH 1),
                         Created TIMESTAMP,
-                        LogLevelName varchar(32),    
+                        LogLevelName varchar(32),
                         Message varchar(256),
                         Module varchar(64),
                         FuncName varchar(64),
@@ -54,4 +55,29 @@ class OracleDBHandler(logging.Handler):
             gauge=record.gauge_id,
         )
         self.session.add(log)
-        self.session.commit()
\ No newline at end of file
+        self.session.commit()
+
+
+def prepare_logging_1() -> None:
+    logging.basicConfig(level=logging.INFO)
+    logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:])
+
+
+def prepare_logging_2(con) -> None:
+
+    old_factory = logging.getLogRecordFactory()
+    def record_factory(*args, **kwargs):
+        record = old_factory(*args, **kwargs)
+        record.gauge_id = kwargs.pop("gauge_id", None)
+        return record
+    logging.setLogRecordFactory(record_factory)
+
+    log_formatter = logging.Formatter(
+        "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    db_handler = OracleDBHandler(con)
+    db_handler.setFormatter(log_formatter)
+
+    db_handler.setLevel(logging.INFO)
+    logging.getLogger().addHandler(db_handler)
+    logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:])
diff --git a/src/utils/db_tools/db_tools.py b/src/utils/db_tools/db_tools.py
index 310e857f7034d01e2250513821f41e07f34c1e1e..c4eb34a28bcf23bc042ad9716f4ab9e5eb57fe66 100644
--- a/src/utils/db_tools/db_tools.py
+++ b/src/utils/db_tools/db_tools.py
@@ -14,10 +14,10 @@ from functools import reduce
 import numpy as np
 import oracledb
 import pandas as pd
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, func
 from sqlalchemy import select, bindparam, between, update
 from sqlalchemy.orm import Session
-from sqlalchemy.dialects.oracle import FLOAT, TIMESTAMP, VARCHAR2
+#from sqlalchemy.dialects.oracle import FLOAT, TIMESTAMP, VARCHAR2
 from sqlalchemy.exc import IntegrityError
 import torch
 
@@ -28,7 +28,7 @@ from utils.db_tools.orm_classes import (
     SensorData,
     InputForecasts,
     PegelForecasts,
-    Log,
+    Metric,
     Sensor,
     InputForecastsMeta,
     Modell
@@ -124,34 +124,31 @@ class OracleWaVoConnection:
         model = hp.load_model(gauge_config["model_folder"])
         model.eval()
 
-        created = datetime.now(timezone)
+        created = datetime.now(timezone.utc)
         #get the metadata for the external forecasts (to access them later from the db)
         input_meta_data = self._ext_fcst_metadata(gauge_config)
 
         if len(self.times)> 1: 
-            df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created)
+            df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created,series=True)
+            df_ext = self.load_input_ext_fcst(gauge_config["external_fcst"]) #load external forecasts
 
-            #load external forecasts
-            if self.main_config.get("fake_external"):
-                df_temp = None
-            else:
-                stmst = select(InputForecasts).where(InputForecasts.sensor_name.in_(gauge_config["external_fcst"]),
-                                                    between(InputForecasts.tstamp, self.times[0], self.times[-1]),
-                                                    InputForecasts.member.in_(self.members))
-
-                df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
                     
-            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_temp)
+            dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_ext,model.differencing)
             df_pred = pred_multiple_db(model, dataset)
             df_pred["created"] = created
             df_pred["sensor_name"] = gauge_config["gauge"]
             df_pred["model"] = gauge_config["model_folder"].name
 
+
+            with Session(self.engine) as session:
+                sensor_data_objs = df_pred_create.apply(lambda row: SensorData(**row.to_dict()), axis=1).tolist()
+                session.add_all(sensor_data_objs)
+                session.commit()
+
             stmt =select(PegelForecasts.tstamp,PegelForecasts.member).where(
                 PegelForecasts.model==gauge_config["model_folder"].name,
                 PegelForecasts.sensor_name==gauge_config["gauge"],
                 PegelForecasts.tstamp.in_(df_pred["tstamp"]),
-                #between(PegelForecasts.tstamp, df_pred["tstamp"].min(), df_pred["tstamp"].max()),
                 )
                                                                                                                             
             df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
@@ -231,44 +228,145 @@ class OracleWaVoConnection:
             None
         """
         try:
-            if member == -1:
-                df_input = df_base_input.fillna(0)
-            else:
-                # replace fake forecast with external forecast
-                df_input = df_base_input.copy()
-                for input_meta in input_meta_data:
-                    if input_meta.ensemble_members == 1:
-                        temp_member = 0
-                    elif input_meta.ensemble_members == 21:
-                        temp_member = member
-                    else:
-                        temp_member = member #TODO was hier tun?
-
-                    fcst_data = df_temp[(df_temp["member"]== temp_member) & (df_temp["sensor_name"] == input_meta.sensor_name)]
-                    if len(fcst_data) == 0:
-                        raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
-
-                    # Replace missing values with forecasts.
-                    # It is guaranteed that the only missing values are in the "future" (after end_time)
-                    assert df_input[:-48].isna().sum().sum() == 0
-                    nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
-                    nan_indices = nan_indices[nan_indices >= 0]
-                    nan_indices2 = nan_indices + (144 - 48)
-                    col_idx = df_input.columns.get_loc(input_meta.sensor_name)
-                    df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
-
+            df_input = self.fill_base_input(member,df_base_input,input_meta_data,df_temp)
         except LookupError as e:
             logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]})
             y = torch.Tensor(48).fill_(np.nan)
         else:
-            y = pred_single_db(model, df_input)
+            y = pred_single_db(model, df_input,differencing=model.differencing)
         finally:
             self.insert_forecast(
                 y,gauge_config,end_time, member, created
             )
 
+    def handle_ensemble(self, ensemble_config: dict) -> None:
+        """
+        Makes ensemble forecasts and inserts them into the database.
+        If no ensemble model is specified in the ensemble_config, the mean of the members is used.
+        Args:
+            ensemble_config (dict): The configuration for the ensemble, model_folder, ensemble_name, ensemble_model.
+        """
+        #TODO add logging here
+        #TODO zrxp exporte?
+        created = datetime.now(timezone.utc)
+        # Collect the base models
+        stmt = select(Modell).where(Modell.modellname.in_(bindparam("model_names",expanding=True)))
+        df_models = pd.read_sql(sql=stmt,con=self.engine,params={"model_names": ensemble_config["model_names"]})
+        
+        pseudo_model_names = []
+        model_list = []
+        for model_name in ensemble_config["model_names"]:
+            model_path = Path(df_models[df_models["modellname"] == model_name]["modelldatei"].values[0])
+            model_list.append(model_path)
+            pseudo_model_names.append(model_path.name)
+
+        fake_config = self._generate_fake_config(df_models,ensemble_config)
+        input_meta_data = self._ext_fcst_metadata(fake_config)
+        df_ext = self.load_input_ext_fcst(fake_config["external_fcst"]) #load external forecasts
+
+        #load ensemble model, maybe with dummys
+        if ensemble_config["mode"] != 'naiv':
+            model, ensemble_yaml = hp.load_ensemble_model(ensemble_config["ensemble_model"],model_list,dummy=ensemble_config["mode"] == 'only_ensemble')
+            model.set_return_base(ensemble_config["mode"] == 'save_both')
+            differencing = ensemble_yaml["differencing"]
+
+        
+        #save_both -> calculate the base models and the ensemble model and save both
+        #calc_both> calculate the base models and the ensemble model but only save the ensemble model
+
+        #Warning: This behaves slightly differently from the other versions.
+        #Here the external forecast only overwrite their corresponding sensor values if they are nan.
+        #This means the model has more information for older timestamps.
+        if ensemble_config["mode"] == 'only_ensemble' or ensemble_config["mode"] == 'naiv':
+            stmt = select(PegelForecasts).where(
+                PegelForecasts.tstamp.in_(bindparam("times",expanding=True)),
+                PegelForecasts.model.in_(bindparam("pseudo_model_names", expanding=True)))
+            df_ens = pd.read_sql(sql=stmt,con=self.engine,index_col="tstamp",
+                params={"times": self.times,"pseudo_model_names": pseudo_model_names})
+
+            #self.insert_forecast(y,fake_config,cur_time, member, created)
+            #TODO this could be much more efficient if i collected the forecasts i think
+            if ensemble_config["mode"] == 'naiv': # -> load the results of the base models from the database and take the naive mean
+                for cur_time in df_ens.index.unique():
+                    df_y_time_slice = df_ens.loc[cur_time]
+                    for member in self.members:
+                        df_y_member = df_y_time_slice[df_y_time_slice["member"] == member]
+                        y = df_y_member.iloc[:,4:].mean()
+                        y = torch.Tensor(y)
+                        self.insert_forecast(y,fake_config,cur_time, member, created)
+
+            if ensemble_config["mode"] == 'only_ensemble': # -> load the results of the base models from the database
+                for cur_time in df_ens.index.unique():
+                    df_y_time_slice = df_ens.loc[cur_time]
+                    df_base_input = self.load_input_db(model.max_in_size,cur_time,fake_config,created)
+                    for member in self.members:
+                        df_y_member = df_y_time_slice[df_y_time_slice["member"] == member]
+                        y_ens = df_y_member.iloc[:,4:].values
+                        try:
+                            x = self._get_x(df_base_input,member,input_meta_data,df_ext,differencing,model)
+                        except LookupError as e:
+                            logging.error(e.args[0],extra={"gauge":fake_config["gauge"]})
+                            y = torch.Tensor(48).fill_(np.nan)
+                        else:
+                            w = model.weighting(x)
+                            y = torch.sum((torch.tensor(y_ens)*w),axis=1)[0]
+                        finally:
+                            self.insert_forecast(y,fake_config,cur_time, member, created)
+            
+        elif ensemble_config["mode"] == 'save_both' or ensemble_config["mode"] == 'calc_both':
+            #df_base_input_all = self.load_input_db(model.max_in_size,self.times[-1],fake_config,created)
+
+            df_base_input = self.load_input_db(model.max_in_size, self.times[-1], fake_config,created,series=True)
+            ds = self._make_dataset(fake_config,df_base_input, model,input_meta_data,df_ext,ensemble_yaml["differencing"])
+            #df_pred = pred_multiple_db(model, dataset)
+            data_loader = DataLoader(ds, batch_size=256,num_workers=1)
+            pred = hp.get_pred(model, data_loader,None)
+
+            df_base_preds = []                
+            df_template = ds.get_whole_index()
+            if ensemble_config["mode"] == 'save_both':
+                pred, base_preds = pred
+                for j in range(base_preds.shape[1]):
+                    df_temp = pd.concat([df_template, pd.DataFrame(base_preds[:,j,:],columns=[f"h{i}" for i in range(1,49)])], axis=1)
+                    df_temp["model"] = model_list[j].name
+                    df_base_preds.append(df_temp)
+                    
+            df_temp = pd.concat([df_template, pd.DataFrame(pred,columns=[f"h{i}" for i in range(1,49)])], axis=1)
+            df_temp["model"] = fake_config["model_folder"].name
+            df_base_preds.append(df_temp)
+
+            df_pred = pd.concat(df_base_preds)
+            df_pred["created"] = created
+            df_pred["sensor_name"] = fake_config["gauge"]
+
 
-    def load_input_db(self, in_size, end_time, gauge_config,created) -> pd.DataFrame:
+            #pseudo_model_names
+            stmt =select(PegelForecasts.tstamp,PegelForecasts.member,PegelForecasts.model).where(
+                #PegelForecasts.model==fake_config["model_folder"].name,
+                #PegelForecasts.model.in_(pseudo_model_names + [fake_config["model_folder"].name]),
+                PegelForecasts.model.in_(df_pred["model"].unique()),
+                PegelForecasts.sensor_name==fake_config["gauge"],
+                PegelForecasts.tstamp.in_(df_pred["tstamp"].unique()),
+                )
+                                                                                                                            
+            df_already_inserted = pd.read_sql(sql=stmt,con=self.engine)
+            merged = df_pred.merge(df_already_inserted[['tstamp', 'member','model']],on=['tstamp', 'member','model'], how='left',indicator=True)
+            df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+            df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1)
+            #sqlalchemy.orm.exc.StaleDataError: UPDATE statement on table 'pegel_forecasts' expected to update 13860 row(s); 12600 were matched.
+            with Session(self.engine) as session:
+                forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist()
+                session.add_all(forecast_objs)
+                session.execute(update(PegelForecasts),df_pred_update.to_dict(orient='records'))
+                session.commit()
+
+
+
+
+
+
+
+    def load_input_db(self, in_size, end_time, gauge_config,created,series=False) -> pd.DataFrame:
         """
         Loads input data from the database based on the current configuration.
 
@@ -277,6 +375,7 @@ class OracleWaVoConnection:
             end_time (pd.Timestamp): The end time of the input data.
             gauge_config (dict): The configuration for the gauge, important here are the columns and external forecasts.
             created (pd.Timestamp): The timestamp of the start of the forecast creation.
+            series (bool): If True, the input data is loaded all at once; this doesn't work with the standard way, because the forecasts would be available to models that shouldn't have them yet.
 
         Returns:
             pd.DataFrame: The loaded input data as a pandas DataFrame.
@@ -285,7 +384,7 @@ class OracleWaVoConnection:
         columns=gauge_config["columns"]
         external_fcst=gauge_config["external_fcst"]
         
-        if len(self.times) > 1:
+        if len(self.times) > 1 and series:
             first_end_time = self.times[0]
             start_time = first_end_time - pd.Timedelta(in_size - 1, "hours")
             if self.main_config.get("fake_external"):
@@ -388,6 +487,62 @@ class OracleWaVoConnection:
 
         return df_input
 
+    def load_input_ext_fcst(self,external_fsct = None) -> pd.DataFrame:
+        if self.main_config.get("fake_external"):
+            return None
+        
+        if external_fsct is None:
+            external_fsct =  []
+
+        stmst = select(InputForecasts).where(InputForecasts.sensor_name.in_(external_fsct),
+                                            between(InputForecasts.tstamp, self.times[0], self.times[-1]),
+                                            InputForecasts.member.in_(self.members))
+
+        df_ext = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp")
+        return df_ext
+
+    def fill_base_input(self,member,df_base_input,input_meta_data,df_ext) -> pd.DataFrame:
+        """This function fills the base input data with external forecasts depending on the current (available) member.
+
+        Args:
+            member (_type_): member identifier (-1 to 21)
+            df_base_input (_type_): sensor data from the database
+            input_meta_data (_type_): metadata for the external forecasts
+            df_ext (_type_): external forecast data
+
+        Returns:
+            pd.DataFrame: Input data with external forecasts inserted
+        """
+        if member == -1:
+            df_input = df_base_input.fillna(0)
+        else:
+            # replace fake forecast with external forecast
+            df_input = df_base_input.copy()
+            for input_meta in input_meta_data:
+                if input_meta.ensemble_members == 1:
+                    temp_member = 0
+                elif input_meta.ensemble_members == 21:
+                    temp_member = member
+                else:
+                    temp_member = member #TODO was hier tun?
+
+                fcst_data = df_ext[(df_ext["member"]== temp_member) & (df_ext["sensor_name"] == input_meta.sensor_name)]
+                if len(fcst_data) == 0:
+                    raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
+
+                # Replace missing values with forecasts.
+                # It is guaranteed that the only missing values are in the "future" (after end_time)
+                assert df_input[:-48].isna().sum().sum() == 0
+                nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
+                nan_indices = nan_indices[nan_indices >= 0]
+                nan_indices2 = nan_indices + (144 - 48)
+                col_idx = df_input.columns.get_loc(input_meta.sensor_name)
+                df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
+                
+        return df_input
+
+
+
 
     def add_zrxp_data(self, zrxp_folders: List[Path]) -> None:
         """Adds zrxp data to the database
@@ -549,7 +704,7 @@ class OracleWaVoConnection:
             fcst_values = {f"h{i}": None for i in range(1, 49)}
         else:
             fcst_values = {f"h{i}": forecast[i - 1].item() for i in range(1, 49)}
-            self._export_zrxp(forecast, gauge_config, end_time, member, sensor_name, model_name, created)
+            self._export_zrxp(forecast, gauge_config, end_time, member, sensor_name, model_name)
 
         stmt = select(PegelForecasts).where(
             PegelForecasts.tstamp == bindparam("tstamp"),
@@ -596,8 +751,34 @@ class OracleWaVoConnection:
             #try:
             session.commit()
 
+    def _generate_fake_config(self,df_models,ensemble_config):
+        model_id = df_models.iloc[0,0].item()
+
+        query = (
+            select(ModellSensor)
+            .where(ModellSensor.modell_id == model_id)
+            .order_by(ModellSensor.ix)
+        )
+        df_ms = pd.read_sql(query,self.engine)
+
+        fake_config = {'gauge': df_models["target_sensor_name"].values[0],
+                    'model_folder' : ensemble_config['ensemble_name'],
+                    'external_fcst' : list(df_ms.dropna()["sensor_name"]), #drop the rows without a value in vhs_gebiet,
+                    'columns' : list(df_ms["sensor_name"])}
+        return fake_config
+    
+    def _get_x(self,df_base_input,member,input_meta_data,df_ext,differencing,model): 
+        df_input = self.fill_base_input(member,df_base_input,input_meta_data,df_ext)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(action="ignore", category=UserWarning)
+            if differencing == 1:
+                df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
+            x = model.scaler.transform(df_input.values)
+        x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0)
+        return x
+
 
-    def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name, created):
+    def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name):
         if self.main_config.get("export_zrxp") and member == 0:
 
             try:
@@ -640,10 +821,10 @@ class OracleWaVoConnection:
 
 
 
-    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext):
+    def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext,differencing):
         with warnings.catch_warnings():
             warnings.filterwarnings(action="ignore", category=UserWarning)
-            if model.differencing == 1:
+            if differencing == 1:
                 df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
             #TODO time embedding stuff
             df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns)
@@ -657,20 +838,41 @@ class OracleWaVoConnection:
         ds = DBDataSet(df_scaled,model.in_size,input_meta_data,df_ext)
         return ds
 
-        
+    def manually_add_sensor_data(self,data_file: Path,sensor_names=None):
+        """This function isn't called anywhere, but it could be used manually to add a csv table of sensordata to the database
+
+        Args:
+            data_file (Path|str): path to csv file
+            sensor_names (list[str], optional): List of column names, if the columns in the csv file don't match the database names. Defaults to None.
+        """
+        print("WARNING: Do not use dataframes with 'shifted' columns, i.e. columns where a forecast is faked by shifting the values by 48 hours")
+        #df = pd.read_csv("../data/input/2025/mergedWillenscharenAll6.csv",index_col=0,parse_dates=True)
+        #df.columns = model.scaler.feature_names_in_[:-1]
+        df = pd.read_csv(data_file,index_col=0,parse_dates=True)
+        if sensor_names is not None:
+            df.columns = sensor_names
+
+        df = pd.melt(df,var_name="sensor_name",value_name="sensor_value",ignore_index=False)
+        stmt = select(SensorData).where(
+            between(SensorData.tstamp, df.index.min(), df.index.max()),
+            SensorData.sensor_name.in_(df["sensor_name"].unique()))
+        df_already_inserted = pd.read_sql(stmt,self.engine)
+        df = df.reset_index()
+        merged = df.merge(df_already_inserted[['tstamp', 'sensor_name']],on=['tstamp', 'sensor_name'], how='left',indicator=True)
+        df_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1)
+        df_create = df_create.dropna()
+        with Session(self.engine) as session:
+            sensor_data_objs = df_create.apply(lambda row: SensorData(**row.to_dict()), axis=1).tolist()
+            session.add_all(sensor_data_objs)
+            session.commit()
+
     def maybe_update_tables(self) -> None:
         """
         Updates the database tables if necessary based on the configuration settings.
 
-        If the 'load_sensor' flag is set to True in the main configuration, it adds sensor data
-        from the specified sensor_folder to the database.
-
         If the 'load_zrxp' flag is set to True in the main configuration, it adds zrxp data
         from the specified zrxp_folder to the database.
         """
-        if self.main_config.get("load_sensor"):
-            self.add_sensor_data(self.main_config["sensor_folder"], force=False)
-
         if self.main_config.get("load_zrxp"):
             self.add_zrxp_data(self.main_config["zrxp_folder"])
 
@@ -726,106 +928,24 @@ class OracleWaVoConnection:
 
         return input_meta
 
-    # @deprecated("This method is deprecated, data can be loaded directly from wiski (Not part of this package).")
-    def add_sensor_data(self, sensor_folder, force=False) -> None:
-        """
-        Adds sensor data to the database.
-        Each csv file in the sensor folder is read and resampled to hourly values, melted and written to a tmp table.
-        The tmp table is then merged into the sensor_data table.
-        NULL values are always replaced, other values only if force is True.
-        Columns are ["tstamp", "sensor_name", "sensor_value"]
-
-        Args:
-            sensor_folder (str): The path to the folder containing the sensor data files.
-            force (bool, optional): If True, forces the update of existing sensor data. Defaults to False.
-        """
-        # TODO this method isn't used anymore
-        sensor_folder = Path(sensor_folder)
-        if sensor_folder.is_dir():
-            sensor_files = list(sensor_folder.iterdir())
-        elif sensor_folder.is_file():
-            sensor_files = [sensor_folder]
-
-        for sensor_file in filter(
-            lambda x: x.is_file() and x.suffix == ".csv", sensor_files
-        ):
-            df = pd.read_csv(sensor_file, index_col=0, parse_dates=True)
-            df = df.resample("h").mean()
-            df["tstamp"] = df.index
-
-            # get all columns with precipitation and fill missing values with 0
-            mask = df.columns.str.contains("NEW") | df.columns.str.contains("NVh")
-            na_count = df.isna().sum(axis=0)
-            prec_cols = list(na_count[mask][na_count[mask] > 0].index)
-            if len(prec_cols) > 0:
-                df.loc[:, mask] = df.loc[:, mask].fillna(0)
-            # interpolate data in all other columns
-            df = df.interpolate(limit=8, limit_direction="both")  # TODO limit parameter
-
-            df = df.replace({np.nan: None})
-
-            df_long = df.melt(id_vars="tstamp")
-            df_long.columns = ["tstamp", "sensor_name", "sensor_value"]
-
-            # lower case, because of something like this: https://github.com/SAP/sqlalchemy-hana/issues/22
-            df_long.to_sql(
-                "tmp_sensor_data",
-                con=self.engine,
-                if_exists="replace",
-                index=False,
-                dtype={
-                    "tstamp": TIMESTAMP(),
-                    "sensor_name": VARCHAR2(128),
-                    "sensor_value": FLOAT(),
-                },
-            )
-
-            # Merge TMP table into sensor_data table
-            with self.con.cursor() as cursor:
-                # This feels ineffiecient, if you know more about SQL redo it.
-                if force:
-                    merge_sensor_query = """MERGE INTO sensor_data
-                        USING tmp_sensor_data 
-                        ON (sensor_data.tstamp = TMP_SENSOR_DATA.tstamp AND sensor_data.sensor_name = TMP_SENSOR_DATA.sensor_name)
-                            WHEN MATCHED THEN
-                                UPDATE SET 
-                                    sensor_data.sensor_value = TMP_SENSOR_DATA.sensor_value
-                                WHERE
-                                    sensor_data.sensor_value <> TMP_SENSOR_DATA.sensor_value
-
-                            WHEN NOT MATCHED THEN
-                                INSERT values (TMP_SENSOR_DATA.tstamp,TMP_SENSOR_DATA.sensor_name,TMP_SENSOR_DATA.sensor_value)"""
-                else:
-                    merge_sensor_query = """MERGE INTO sensor_data
-                        USING tmp_sensor_data 
-                        ON (sensor_data.tstamp = TMP_SENSOR_DATA.tstamp AND sensor_data.sensor_name = TMP_SENSOR_DATA.sensor_name)
-                            WHEN MATCHED THEN
-                                UPDATE SET 
-                                    sensor_data.sensor_value = COALESCE(sensor_data.sensor_value,TMP_SENSOR_DATA.sensor_value)
-                                WHERE
-                                    sensor_data.sensor_value <> TMP_SENSOR_DATA.sensor_value
-                            WHEN NOT MATCHED THEN
-                                INSERT values (TMP_SENSOR_DATA.tstamp,TMP_SENSOR_DATA.sensor_name,TMP_SENSOR_DATA.sensor_value)"""
-
-                cursor.execute(merge_sensor_query)
-                self.con.commit()
 
 
 @torch.no_grad()
-def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
+def pred_single_db(model, df_input: pd.DataFrame,differencing=True) -> torch.Tensor:
     """
     Predicts the output for a single input data point.
 
     Args:
         model (LightningModule): The model to use for prediction.
         df_input (pandas.DataFrame): Input data point as a DataFrame.
+        differencing (bool): If True a column is inserted at the last place with differencing of the target value
 
     Returns:
         torch.Tensor: Predicted output as a tensor.
     """
     with warnings.catch_warnings():
         warnings.filterwarnings(action="ignore", category=UserWarning)
-        if model.differencing == 1:
+        if differencing == 1:
             df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
         #TODO time embedding stuff
         x = model.scaler.transform(
@@ -860,40 +980,82 @@ def pred_multiple_db(model, ds: Dataset) -> torch.Tensor:
     return df_pred
 
 
-def get_merge_forecast_query(out_size: int = 48) -> str:
-    """Returns a merge query for the PEGEL_FORECASTS table with the correct number of forecast hours."""
-
-    merge_forecast_query = """MERGE INTO PEGEL_FORECASTS
-        USING TMP_PEGEL_FORECASTS 
-        ON (
-            PEGEL_FORECASTS.tstamp = TMP_PEGEL_FORECASTS.tstamp AND 
-            PEGEL_FORECASTS.sensor_name = TMP_PEGEL_FORECASTS.sensor_name AND 
-            PEGEL_FORECASTS.model = TMP_PEGEL_FORECASTS.model AND 
-            PEGEL_FORECASTS.member = TMP_PEGEL_FORECASTS.member)
-        WHEN MATCHED THEN
-            UPDATE SET
-                PEGEL_FORECASTS.created = TMP_PEGEL_FORECASTS.created, """
-
-    merge_forecast_query += (
-        ",".join(
-            [
-                f"\n PEGEL_FORECASTS.h{i} = TMP_PEGEL_FORECASTS.h{i}"
-                for i in range(1, out_size + 1)
-            ]
+
def add_model(engine, config, model_path, in_size=144) -> int:
    """Register a new model in the MODELL table.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        config (dict): Training configuration; reads ``data.target_column``,
            ``data.sensors`` and ``general.model_name_prefix``.
        model_path (str | Path): Path to the stored model folder.
        in_size (int, optional): Input window length stored with the model.
            Defaults to 144.

    Returns:
        int: Id of the newly inserted row, or -1 if a model with the same
        name already exists.
    """
    # Normalize to an absolute Path regardless of the input type.
    # (The original only resolved str inputs, leaving Path inputs relative.)
    model_path = Path(model_path).resolve()

    mstnr = config["data"]["target_column"].split(',')[0]
    target_sensor_name = config["data"]["target_column"]
    model_prefix = config['general']['model_name_prefix']
    model_name = f"{model_prefix}_{model_path.name}"

    with Session(engine) as session:
        old_model = session.scalars(select(Modell).where(Modell.modellname == model_name)).first()

        if old_model is not None:
            logging.info("model %s couldn't be added, model with this name already exists", model_name)
            return -1

        # Ids are assigned manually as max(id) + 1; presumably the table has
        # no autoincrement/sequence — TODO confirm against the schema.
        model_id = (session.execute(select(func.max(Modell.id))).scalar() or 0) + 1
        model_obj = Modell(
            id=model_id,
            mstnr=mstnr,
            modellname=model_name,
            modelldatei=str(model_path),
            aktiv=0,
            kommentar=' \n'.join(config["data"]["sensors"]),
            in_size=in_size,
            target_sensor_name=target_sensor_name,
        )
        session.add(model_obj)
        session.commit()
        logging.info("Added model %s to database", model_obj)

    # Return the locally kept id: after commit the instance's attributes are
    # expired (expire_on_commit=True) and the session is closed, so
    # model_obj.id would raise DetachedInstanceError here.
    return model_id
 
-    merge_forecast_query += """WHEN NOT MATCHED THEN
-            INSERT values(TMP_PEGEL_FORECASTS.tstamp,TMP_PEGEL_FORECASTS.sensor_name,TMP_PEGEL_FORECASTS.model,TMP_PEGEL_FORECASTS.member,TMP_PEGEL_FORECASTS.created,"""
def add_model_sensor(engine, model_id, sensor_names, external_fcst) -> None:
    """Store the input sensors of a model in the MODELL_SENSOR table.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        model_id (int): Id of the model the sensors belong to.
        sensor_names: Input column names in model order; may contain the
            synthetic differencing column "d1".
        external_fcst: Collection of sensor names for which external
            forecasts exist.

    Raises:
        LookupError: If a sensor listed in ``external_fcst`` has no entry
            in INPUT_FORECASTS_META.
    """
    with Session(engine) as session:
        for i, col in enumerate(sensor_names):
            # Reset every iteration: the original assigned vhs_gebiet only
            # when col != "d1", so a "d1" column reused the previous loop's
            # value (or raised NameError if "d1" came first).
            vhs_gebiet = None
            if col != "d1" and col in external_fcst:
                vhs_gebiet = session.scalars(
                    select(InputForecastsMeta.vhs_gebiet).where(InputForecastsMeta.sensor_name == col)
                ).first()
                if vhs_gebiet is None:
                    raise LookupError(f"Sensor {col} not found in INPUT_FORECASTS_META")

            session.add(ModellSensor(
                modell_id=model_id,
                sensor_name=col,
                vhs_gebiet=vhs_gebiet,
                ix=i,
            ))
        session.commit()
+        
def add_metrics(engine, model_id, metrics, time_ranges) -> None:
    """Persist per-timestep evaluation metrics for a model.

    Args:
        engine: SQLAlchemy engine connected to the target database.
        model_id (int): Id of the model the metrics belong to.
        metrics (dict): Metric name -> value. Only array-like values of
            length 48 (one entry per forecast step) are stored; scalars
            and other shapes are skipped.
        time_ranges (dict): Maps a set name ("train"/"val"/...) to a dict
            with "start" and "end" of the evaluated time span.
    """
    rows = []
    for raw_name, values in metrics.items():
        name = raw_name.replace("hp/", "")
        set_name = name.split("_")[0]
        # Metrics evaluated on flood periods only carry a "_flood" suffix.
        tag = "flood" if name.endswith("_flood") else "normal"
        name = name.replace("_flood", "")
        span = time_ranges[set_name]

        # 0-d tensors have an empty shape; only 48-step vectors are stored.
        if len(values.shape) > 0 and len(values) == 48:
            rows.extend(
                Metric(
                    model_id=model_id,
                    metric_name=name,
                    timestep=step,
                    start_time=span["start"],
                    end_time=span["end"],
                    value=val,
                    tag=tag,
                )
                for step, val in enumerate(values, start=1)
            )

    with Session(engine) as session:
        session.add_all(rows)
        session.commit()
 
+                
 
 
 class MissingTimestampsError(LookupError):
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index ce591b6212f15526f1ab9b144d23afc0bbd349d9..4a44f3594f1cdce01ec82a5787207aa2d3d933f3 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -16,6 +16,7 @@ from torch.utils.data import DataLoader
 # from data_tools.data_module import WaVoDataModule
 from data_tools.data_module import WaVoDataModule
 from models.lightning_module import WaVoLightningModule
+from models.ensemble_models import WaVoLightningEnsemble
 
 
 def load_settings_model(model_dir: Path) -> dict:
@@ -90,11 +91,17 @@ def get_pred(
     if move_model:
         model.cuda()
 
-    y_pred = np.concatenate(pred)
+    if getattr(model, 'return_base', None):
+        pred, base_preds = pred[0] #TODO this function is not tested with more than one batch 
+        return pred, base_preds
 
-    if y_true is not None:
-        y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred?
-        y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
+    else:
+            
+        y_pred = np.concatenate(pred)
+
+        if y_true is not None:
+            y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred?
+            y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
     return y_pred
 
 
@@ -211,3 +218,39 @@ def load_modules_rl(
     data_module.setup(stage="")
 
     return model_list, data_module
+
+
def load_ensemble_model(model_dir, model_dir_list, dummy=False):
    """Load a WaVoLightningEnsemble together with its settings.

    The base models may be given either as paths (loaded here) or as
    already-instantiated WaVoLightningModules, assumed to be frozen.

    Args:
        model_dir: Folder containing the ensemble checkpoint and settings.
        model_dir_list: Paths of the base models, or a list of loaded
            WaVoLightningModule instances.
        dummy (bool, optional): If True only the first base model in
            model_dir_list is actually loaded and reused for every slot.
            Defaults to False.

    Returns:
        tuple: (ensemble model, dict with settings read from model_dir)
    """
    model_dir = Path(model_dir)
    ensemble_yaml = load_settings_model(model_dir)

    if isinstance(model_dir_list[0], WaVoLightningModule):
        # Caller already supplies loaded (frozen) modules — use as-is.
        model_list = model_dir_list
    else:
        if dummy:
            # Load once and share the same instance across all slots.
            model_list = [load_model(model_dir=model_dir_list[0])] * len(model_dir_list)
        else:
            model_list = [load_model(model_dir=Path(entry)) for entry in model_dir_list]
        for base_model in model_list:
            base_model.eval()

    # NOTE(review): takes the first file found; assumes the checkpoints
    # folder contains exactly one checkpoint — confirm.
    checkpoint_path = next((model_dir / "checkpoints").iterdir())
    ensemble = WaVoLightningEnsemble.load_from_checkpoint(
        checkpoint_path=checkpoint_path, model_list=model_list
    )

    ensemble.gauge_idx = model_list[0].gauge_idx  # hacky
    ensemble.in_size = ensemble.max_in_size
    ensemble.eval()
    return ensemble, ensemble_yaml