diff --git a/configs/db_new.yaml b/configs/db_new.yaml
index 8b4c9a08b6c988f47fd0a49addd9e99579c4151d..a70db6f31a6d7bf8de71f7196fc8fdcac1b1e559 100644
--- a/configs/db_new.yaml
+++ b/configs/db_new.yaml
@@ -10,10 +10,10 @@ main_config:
   zrxp_out_folder : ../data/zrxp_out/
   #if True tries to put  all files in folder in the external forecast table
   load_sensor : False
-  load_zrxp : False
+  load_zrxp : True
   ensemble : True
   single : False
-  dummy : True
+  dummy : False
   #start: 2019-12-01 02:00
   #start: 2023-04-19 03:00
   #start: 2023-11-19 03:00
@@ -24,21 +24,6 @@ main_config:
   range: !!bool False
   zrxp: !!bool True
 gauge_configs:
-  - gauge: hollingstedt
-    model_folder: ../models_torch/hollingstedt_version_14
-    columns:
-      - 4466,SHum,vwsl 
-      - 4466,SHum,bfwls 
-      - 4466,AT,h.Cmd-2 
-      - 114069,S,15m.Cmd 
-      - 111111,S,5m.Cmd 
-      - 9520081,S_Tide,1m.Cmd 
-      - 9530010,S_Tide,1m.Cmd 
-      - 112211,Precip,h.Cmd 
-      - 112211,S,5m.Cmd
-    external_fcst:
-      - 9530010,S_Tide,1m.Cmd
-      - 112211,Precip,h.Cmd
   - gauge: 114547,S,60m.Cmd
     model_folder: ../models_torch/tarp_version_49/
     columns:
@@ -52,6 +37,21 @@ gauge_configs:
     external_fcst:
       - 114547,Precip,h.Cmd
       #114547,Precip,h.Cmd : Tarp
+  - gauge: 112211,S,5m.Cmd
+    model_folder: ../models_torch/hollingstedt_version_37
+    columns:
+      - 4466,SHum,vwsl 
+      - 4466,SHum,bfwls 
+      - 4466,AT,h.Cmd-2 
+      - 114069,S,15m.Cmd 
+      - 111111,S,5m.Cmd 
+      - 9520081,S_Tide,1m.Cmd 
+      - 9530010,S_Tide,1m.Cmd 
+      - 112211,Precip,h.Cmd 
+      - 112211,S,5m.Cmd
+    external_fcst:
+      - 9530010,S_Tide,1m.Cmd
+      - 112211,Precip,h.Cmd
   - gauge: 114069,Precip,h.Cmd
     model_folder: ../models_torch/version_74_treia/
     columns:
diff --git a/configs/db_test_nocuda.yaml b/configs/db_test_nocuda.yaml
index 46f14b33fd6c54ef086ae7d0b7940953da57dd3f..3301f97ba742c7265099f2c8a0743e181e2e2c37 100644
--- a/configs/db_test_nocuda.yaml
+++ b/configs/db_test_nocuda.yaml
@@ -10,20 +10,34 @@ sensor_folder : ../data/db_in/
 zrxp_out_folder : ../data/zrxp_out/
 #if True tries to put  all files in folder in the external forecast table
 load_sensor : False
-load_zrxp : True
+load_zrxp : False
 ensemble : True
 single : False
 dummy : True
 #start: 2019-12-01 02:00
 #start: 2023-04-19 03:00
 #start: 2023-11-19 03:00
-start: 2024-05-02 09:00
+start: 2024-09-13 09:00
 #end: !!str 2021-02-06
 # if range is true will try all values from start to end for predictions
 #range: !!bool True
 range: !!bool False
 zrxp: !!bool True
 ---
+document: 114547,S,60m.Cmd
+model_folder:
+- ../models_torch/tarp_version_49/
+columns:
+  - 4466,SHum,vwsl
+  - 4466,SHum,bfwls
+  - 4466,AT,h.Cmd
+  - 114435,S,60m.Cmd
+  - 114050,S,60m.Cmd
+  - 114547,S,60m.Cmd
+  - 114547,Precip,h.Cmd
+external_fcst:
+  114547,Precip,h.Cmd : Tarp
+---
 document: 114069,Precip,h.Cmd
 model_folder: 
 - ../models_torch/version_74_treia/
diff --git a/notebooks/denmark_scraping.ipynb b/notebooks/denmark_scraping.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..5649506dfb201fc6dda0817edfa11ef17c7a08d2
--- /dev/null
+++ b/notebooks/denmark_scraping.ipynb
@@ -0,0 +1,157 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import requests\n",
+    "from bs4 import BeautifulSoup\n",
+    "from pathlib import Path\n",
+    "import pandas as pd\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import json\n",
+    "import h5py\n",
+    "import io"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "api_key_radar = \"_\"\n",
+    "api_key = \"_\"\n",
+    "\n",
+    "\n",
+    "url_collection_list = f\"https://dmigw.govcloud.dk/v2/climateData/collections?api-key={api_key}\"\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "r = requests.get(url_collection_list)\n",
+    "r.json()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#requests.get(f\"https://dmigw.govcloud.dk/v2/climateData/collections/stationValue/items/?parameterId=acc_precip&limit=10&api-key={api_key}\").json()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "requests.get(f\"https://dmigw.govcloud.dk/v2/climateData/collections/stationValue/items/?timeResolution=hour&parameterId=acc_precip&limit=10&api-key={api_key}\").json()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "r = requests.get(f\"https://dmigw.govcloud.dk/v1/radardata/collections/pseudoCappi/items\",\n",
+    "             params={\"bbox\":\"7,54,16,58\",\n",
+    "                     \"limit\":\"1\",\n",
+    "                     \"api-key\":api_key_radar})\n",
+    "data_url = r.json()[\"features\"][0][\"asset\"][\"data\"][\"href\"]\n",
+    "data_url\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "r2 = requests.get(data_url)\n",
+    "\n",
+    "bio = io.BytesIO(r2.content)\n",
+    "f = h5py.File(bio, 'r')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "f.keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "f[\"dataset1\"].keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#w = f[\"dataset1\"][\"what\"]\n",
+    "data = f[\"dataset1\"][\"data1\"][\"data\"]\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data.shape\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#pd.read_hdf(\"/home/mspils/Downloads/eksn20240517_1420.Z_nn_S1.ps.500.wrk.h5\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "torch",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.17"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/pyplots.ipynb b/notebooks/pyplots.ipynb
index 6ad7cc281ce68636ea765669f7a3972c75ad39c6..39257e2f4d3c24addad029a6f8107b3f48d03abd 100644
--- a/notebooks/pyplots.ipynb
+++ b/notebooks/pyplots.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -11,101 +11,196 @@
     "import sys\n",
     "\n",
     "sys.path.insert(0, '../src')\n",
-    "import utils.utility as ut\n"
+    "import utils.utility as ut\n",
+    "from data_tools.datasets import TimeSeriesDataSet"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "df = pd.read_csv(\"../../data/db_in/FoehrdenBarl3.csv\",index_col=0,parse_dates=True)\n"
+    "from pypots.data.generating import gene_physionet2012\n",
+    "physionet2012_dataset = gene_physionet2012(artificially_missing_rate=0)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2024-03-08 15:11:17.798239: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA\n",
-      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
-      "2024-03-08 15:11:19 [ERROR]: ❌ No module named 'torch_geometric'\n",
-      "Note torch_geometric is missing, please install it with 'pip install torch_geometric torch_scatter torch_sparse' or 'conda install -c pyg pyg pytorch-scatter pytorch-sparse'\n",
-      "2024-03-08 15:11:19 [ERROR]: ❌ name 'MessagePassing' is not defined\n",
-      "Note torch_geometric is missing, please install it with 'pip install torch_geometric torch_scatter torch_sparse' or 'conda install -c pyg pyg pytorch-scatter pytorch-sparse'\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "from pypots.data.generating import gene_physionet2012"
+    "df = pd.read_csv(\"../../data/db_in/FoehrdenBarl3.csv\",index_col=0,parse_dates=True)\n",
+    "df = ut.fill_missing_values(df)\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2024-03-08 15:16:49 [INFO]: Loading the dataset physionet_2012 with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)...\n",
-      "2024-03-08 15:16:49 [INFO]: Starting preprocessing physionet_2012...\n",
-      "2024-03-08 15:16:49 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: \n",
-      "https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012\n",
-      "2024-03-08 15:16:49 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...\n",
-      "2024-03-08 15:16:49 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...\n",
-      "2024-03-08 15:16:49 [INFO]: Loaded successfully!\n"
-     ]
-    }
-   ],
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
    "source": [
-    "physionet2012_dataset = gene_physionet2012(artificially_missing_rate=0)"
+    "in_size = 144\n",
+    "out_size = 48\n",
+    "\n",
+    "\n",
+    "X_train = df[:int(0.7*len(df))].values\n",
+    "y_train = df[\"FoehrdenBarl_pegel_cm\"][:int(0.7*len(df))].values\n",
+    "ds_train = TimeSeriesDataSet(X_train, y_train, in_size, out_size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_train_stacked = []\n",
+    "y_train_stacked = []\n",
+    "\n",
+    "for el in ds_train:\n",
+    "    X_train_stacked.append(el[0])\n",
+    "    y_train_stacked.append(el[1])\n",
+    "\n",
+    "X_train_stacked=  np.stack(X_train_stacked)\n",
+    "y_train_stacked = np.stack(y_train_stacked)\n",
+    "n_features = X_train_stacked.shape[-1]\n",
+    "\n",
+    "dataset_for_training = {\"X\": X_train_stacked, \"y\": y_train_stacked}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "physionet2012_dataset.keys()\n",
+    "physionet2012_dataset[\"n_classes\"]\n",
+    "dataset_for_training = {\n",
+    "    \"X\": np.concatenate([physionet2012_dataset['train_X'], physionet2012_dataset['val_X']], axis=0),\n",
+    "    \"y\": np.concatenate([physionet2012_dataset['train_y'], physionet2012_dataset['val_y']], axis=0),\n",
+    "}\n",
+    "\n",
+    "dataset_for_testing = {\n",
+    "    \"X\": physionet2012_dataset['test_X'],\n",
+    "    \"y\": physionet2012_dataset['test_y'],\n",
+    "}"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 38,
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "dict_keys(['n_classes', 'n_steps', 'n_features', 'train_X', 'train_y', 'train_ICUType', 'val_X', 'val_y', 'val_ICUType', 'test_X', 'test_y', 'test_ICUType', 'scaler'])"
-      ]
-     },
-     "execution_count": 12,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2024-03-08 15:30:09 [INFO]: Using the given device: cpu\n",
+      "2024-03-08 15:30:09 [INFO]: Model files will be saved to tutorial_results/clustering/crli/20240308_T153009\n",
+      "2024-03-08 15:30:09 [INFO]: Tensorboard file will be saved to tutorial_results/clustering/crli/20240308_T153009/tensorboard\n",
+      "2024-03-08 15:30:09 [INFO]: CRLI initialized with the given hyperparameters, the number of trainable parameters: 1,491,428\n"
+     ]
     }
    ],
    "source": [
-    "physionet2012_dataset.keys()"
+    "from pypots.optim import Adam\n",
+    "from pypots.clustering import CRLI\n",
+    "\n",
+    "# initialize the model\n",
+    "crli = CRLI(\n",
+    "    n_steps=in_size,\n",
+    "    n_features=n_features,\n",
+    "    n_clusters=10,\n",
+    "    n_generator_layers=2,\n",
+    "    rnn_hidden_size=256,\n",
+    "    rnn_cell_type=\"GRU\",\n",
+    "    decoder_fcn_output_dims=[256, 128],  # the output dimensions of layers in the decoder FCN.\n",
+    "    # Here means there are 3 layers. Leave it to default as None will results in\n",
+    "    # the FCN haveing only one layer.\n",
+    "    batch_size=32,\n",
+    "    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance\n",
+    "    epochs=10,\n",
+    "    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.\n",
+    "    # You can leave it to defualt as None to disable early stopping.\n",
+    "    patience=3,\n",
+    "    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when\n",
+    "    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.\n",
+    "    G_optimizer=Adam(lr=1e-3),\n",
+    "    D_optimizer=Adam(lr=1e-3),\n",
+    "    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.\n",
+    "    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.\n",
+    "    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed\n",
+    "    num_workers=0,\n",
+    "    # just leave it to default, PyPOTS will automatically assign the best device for you.\n",
+    "    # Set it to 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices.\n",
+    "    device='cpu',  \n",
+    "    # set the path for saving tensorboard and trained model files \n",
+    "    saving_path=\"tutorial_results/clustering/crli\",\n",
+    "    # only save the best model after training finished.\n",
+    "    # You can also set it as \"better\" to save models performing better ever during training.\n",
+    "    model_saving_strategy=\"best\",\n",
+    ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 39,
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "2"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2024-03-08 15:30:39 [ERROR]: ❌ Exception: `predictions` mustn't contain NaN values, but detected NaN in it\n"
+     ]
+    },
+    {
+     "ename": "RuntimeError",
+     "evalue": "Training got interrupted. Model was not trained. Please investigate the error printed above.",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
+      "File \u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.10/site-packages/pypots/clustering/crli/model.py:238\u001b[0m, in \u001b[0;36mCRLI._train_model\u001b[0;34m(self, training_loader, val_loader)\u001b[0m\n\u001b[1;32m    237\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mG_optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m--> 238\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    239\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraining_object\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgenerator\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m    240\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    241\u001b[0m results[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgeneration_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mbackward()\n",
+      "File \u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.10/site-packages/pypots/clustering/crli/modules/core.py:92\u001b[0m, in \u001b[0;36m_CRLI.forward\u001b[0;34m(self, inputs, training_object, training)\u001b[0m\n\u001b[1;32m     89\u001b[0m l_G \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mbinary_cross_entropy_with_logits(\n\u001b[1;32m     90\u001b[0m     inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdiscrimination\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m missing_mask, weight\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m missing_mask\n\u001b[1;32m     91\u001b[0m )\n\u001b[0;32m---> 92\u001b[0m l_pre \u001b[38;5;241m=\u001b[39m \u001b[43mcalc_mse\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mimputation_latent\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmissing_mask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     93\u001b[0m l_rec \u001b[38;5;241m=\u001b[39m calc_mse(inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mreconstruction\u001b[39m\u001b[38;5;124m\"\u001b[39m], X, missing_mask)\n",
+      "File \u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.10/site-packages/pypots/utils/metrics/error.py:159\u001b[0m, in \u001b[0;36mcalc_mse\u001b[0;34m(predictions, targets, masks)\u001b[0m\n\u001b[1;32m    158\u001b[0m \u001b[38;5;66;03m# check shapes and values of inputs\u001b[39;00m\n\u001b[0;32m--> 159\u001b[0m lib \u001b[38;5;241m=\u001b[39m \u001b[43m_check_inputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredictions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmasks\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m masks \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+      "File \u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.10/site-packages/pypots/utils/metrics/error.py:36\u001b[0m, in \u001b[0;36m_check_inputs\u001b[0;34m(predictions, targets, masks, check_shape)\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[38;5;66;03m# check NaN\u001b[39;00m\n\u001b[0;32m---> 36\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m lib\u001b[38;5;241m.\u001b[39misnan(\n\u001b[1;32m     37\u001b[0m     predictions\n\u001b[1;32m     38\u001b[0m )\u001b[38;5;241m.\u001b[39many(), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`predictions` mustn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt contain NaN values, but detected NaN in it\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     39\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m lib\u001b[38;5;241m.\u001b[39misnan(\n\u001b[1;32m     40\u001b[0m     targets\n\u001b[1;32m     41\u001b[0m )\u001b[38;5;241m.\u001b[39many(), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`targets` mustn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt contain NaN values, but detected NaN in it\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
+      "\u001b[0;31mAssertionError\u001b[0m: `predictions` mustn't contain NaN values, but detected NaN in it",
+      "\nDuring handling of the above exception, another exception occurred:\n",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[39], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mcrli\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_set\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset_for_training\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.10/site-packages/pypots/clustering/crli/model.py:375\u001b[0m, in \u001b[0;36mCRLI.fit\u001b[0;34m(self, train_set, val_set, file_type)\u001b[0m\n\u001b[1;32m    367\u001b[0m     val_loader \u001b[38;5;241m=\u001b[39m DataLoader(\n\u001b[1;32m    368\u001b[0m         val_set,\n\u001b[1;32m    369\u001b[0m         batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_size,\n\u001b[1;32m    370\u001b[0m         shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    371\u001b[0m         num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_workers,\n\u001b[1;32m    372\u001b[0m     )\n\u001b[1;32m    374\u001b[0m \u001b[38;5;66;03m# Step 2: train the model and freeze it\u001b[39;00m\n\u001b[0;32m--> 375\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtraining_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    376\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mload_state_dict(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbest_model_dict)\n\u001b[1;32m    377\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39meval()  \u001b[38;5;66;03m# set the model as eval status to freeze it.\u001b[39;00m\n",
+      "File \u001b[0;32m~/anaconda3/envs/pytorch/lib/python3.10/site-packages/pypots/clustering/crli/model.py:332\u001b[0m, in \u001b[0;36mCRLI._train_model\u001b[0;34m(self, training_loader, val_loader)\u001b[0m\n\u001b[1;32m    330\u001b[0m logger\u001b[38;5;241m.\u001b[39merror(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m❌ Exception: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    331\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbest_model_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 332\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    333\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining got interrupted. Model was not trained. Please investigate the error printed above.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    334\u001b[0m     )\n\u001b[1;32m    335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    336\u001b[0m     \u001b[38;5;167;01mRuntimeWarning\u001b[39;00m(\n\u001b[1;32m    337\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining got interrupted. Please investigate the error printed above.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    338\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel got trained and will load the best checkpoint so far for testing.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    339\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIf you don\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt want it, please try fit() again.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    340\u001b[0m     )\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: Training got interrupted. Model was not trained. Please investigate the error printed above."
+     ]
     }
    ],
    "source": [
-    "physionet2012_dataset[\"n_classes\"]"
+    "crli.fit(train_set=dataset_for_training)"
    ]
   },
   {
diff --git a/src/predict_database.py b/src/predict_database.py
index ab30b028f4508aad912fbcf66d00f12d25ff2bdb..19207e7ec366be7a8ee12d5573f448a9beb783d9 100644
--- a/src/predict_database.py
+++ b/src/predict_database.py
@@ -68,14 +68,14 @@ def _prepare_logging_2(con) -> None:
         return record
     logging.setLogRecordFactory(record_factory)
 
-    logFormatter = logging.Formatter(
+    log_formatter = logging.Formatter(
         "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S"
     )
-    dbHandler = OracleDBHandler(con)
-    dbHandler.setFormatter(logFormatter)
+    db_handler = OracleDBHandler(con)
+    db_handler.setFormatter(log_formatter)
 
-    dbHandler.setLevel(logging.INFO)
-    logging.getLogger().addHandler(dbHandler)
+    db_handler.setLevel(logging.INFO)
+    logging.getLogger().addHandler(db_handler)
     logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:])
 
 
@@ -139,6 +139,5 @@ def main(passed_args) -> None:
 
 
 if __name__ == "__main__":
-    # TODO argparse einschränken auf bestimmte pegel?
-    args = parse_args()
-    main(args)  # parse_args()
+    my_args = parse_args()
+    main(my_args)
diff --git a/src/utils/db_tools.py b/src/utils/db_tools.py
index 65af2e2533853a74b1b2acd76919407d62b38eea..7692012c7d8ab4818accb00009494e111b355816 100644
--- a/src/utils/db_tools.py
+++ b/src/utils/db_tools.py
@@ -5,7 +5,7 @@ A collection of classes and functions for interacting with the oracle based Wavo
 import logging
 from datetime import datetime
 from pathlib import Path
-from typing import List, Tuple
+from typing import List
 import warnings
 
 # from warnings import deprecated (works from python 3.13 on)
@@ -14,15 +14,11 @@ import numpy as np
 import oracledb
 import pandas as pd
 from sqlalchemy import create_engine
-from sqlalchemy import select, bindparam, between, update, insert
+from sqlalchemy import select, bindparam, between, update
 from sqlalchemy.orm import Session
-from sqlalchemy.dialects.oracle import FLOAT, TIMESTAMP, VARCHAR2, NUMBER
+from sqlalchemy.dialects.oracle import FLOAT, TIMESTAMP, VARCHAR2
 from sqlalchemy.exc import IntegrityError
 import torch
-from torch.utils.data import Dataset, DataLoader
-import lightning.pytorch as pl
-
-from tqdm import tqdm
 
 import utils.helpers as hp
 from utils.orm_classes import (
@@ -37,6 +33,7 @@ from utils.orm_classes import (
 
 # pylint: disable=unsupported-assignment-operation
 # pylint: disable=unsubscriptable-object
+# pylint: disable=line-too-long
 
 
 class OracleWaVoConnection:
@@ -65,12 +62,11 @@ class OracleWaVoConnection:
         """
 
             # self.cur_model_folder = model_folder
-        
         model = hp.load_model(gauge_config["model_folder"])
         model.eval()
 
         created = datetime.now()
-        for end_time in tqdm(self.times):
+        for end_time in self.times:
             # load sensor sensor input data
             df_base_input = self.load_input_db(
                 in_size=model.in_size,
@@ -83,15 +79,17 @@ class OracleWaVoConnection:
 
             # predict
             # Get all external forecasts for this gauge
-            #stmt = select(InputForecastsMeta).where(InputForecastsMeta.sensor_name.in_(bindparam("external_fcst")))
-            #params = {"external_fcst" : gauge_config["external_fcst"]}
-            #with Session(self.engine) as session:
-            #    x = session.scalars(statement=stmt, params=params).fetchall()
-            #    sensor_names = [el.sensor_name for el in x]
-            #params2 = {"ext_forecasts" : sensor_names, "tstamp" : end_time}
+            stmt = select(InputForecastsMeta).where(InputForecastsMeta.sensor_name.in_(bindparam("external_fcst")))
+            params = {"external_fcst" : gauge_config["external_fcst"]}
+            with Session(self.engine) as session:
+                input_meta_data = session.scalars(statement=stmt, params=params).fetchall()
+
             stmst2 = select(InputForecasts).where(InputForecasts.sensor_name.in_(bindparam("ext_forecasts")),InputForecasts.tstamp == (bindparam("tstamp")))
             params2 = {"ext_forecasts" : gauge_config["external_fcst"], "tstamp" : end_time}
             df_temp = pd.read_sql(sql=stmst2,con=self.engine,index_col="tstamp",params=params2)
+            if len(gauge_config["external_fcst"]) != len(input_meta_data):
+                logging.error("Not all external can be found in InputForecastsMeta, only %s",[x.sensor_name for x in input_meta_data],extra={"gauge" : gauge_config["gauge"]})
+                return
 
             for member in self.members:
                 self.handle_member(
@@ -100,7 +98,9 @@ class OracleWaVoConnection:
                     df_base_input,
                     member,
                     end_time,
-                    created,
+                    input_meta_data,
+                    df_temp,
+                    created
                 )
 
 
@@ -111,6 +111,8 @@ class OracleWaVoConnection:
         df_base_input: pd.DataFrame,
         member: int,
         end_time,
+        input_meta_data,
+        df_temp,
         created: pd.Timestamp = None,
     ) -> None:
         """
@@ -123,25 +125,47 @@ class OracleWaVoConnection:
             df_base_input (DataFrame): The base input DataFrame with measure values, still need the actual precipitation forecast.
             member (int): The member identifier.
             end_time (pd.Timestamp): The timestamp of the forecast.
+            input_meta_data (list): List of external forecast metadata.
+            df_temp (DataFrame): DataFrame with external forecasts.
             created (pd.Timestamp): The timestamp of the start of the forecast creation.
 
         Returns:
             None
         """
-        if member == -1:
-            df_input = df_base_input.fillna(0)
+        try:
+            if member == -1:
+                df_input = df_base_input.fillna(0)
+            else:
+                # replace fake forecast with external forecast
+                df_input = df_base_input.copy()
+                for input_meta in input_meta_data:
+                    if input_meta.ensemble_members == 1:
+                        temp_member = 0
+                    elif input_meta.ensemble_members == 21:
+                        temp_member = member
+                    else:
+                        raise LookupError(f"Unexpected ensemble_members={input_meta.ensemble_members} for external forecast {input_meta.sensor_name}")
+
+                    fcst_data = df_temp[(df_temp["member"] == temp_member) & (df_temp["sensor_name"] == input_meta.sensor_name)]
+                    if len(fcst_data) == 0:
+                        raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}")
+
+                    # Replace missing values with forecasts.
+                    # It is guaranteed that the only missing values are in the "future" (after end_time)
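+                    # NOTE: the offsets below assume a 144-step input window with a 48-step forecast horizon
+                    # (matching in_size/out_size used elsewhere); only the trailing 48 rows should contain NaNs.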
+                    assert df_input[:-48].isna().sum().sum() == 0
+                    nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48)
+                    nan_indices = nan_indices[nan_indices >= 0]
+                    nan_indices2 = nan_indices + (144 - 48)
+                    col_idx = df_input.columns.get_loc(input_meta.sensor_name)
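+                    # values[0, 2:] presumably skips the sensor_name and member columns, leaving only the forecast values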
+                    df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices]
+
+        except LookupError as e:
+            logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]})
+            y = torch.Tensor(48).fill_(np.nan)
         else:
-
-
-            # replace fake forecast with external forecast
-            df_input = self.merge_synth_fcst(df_base_input, member, end_time, gauge_config,created)
-        if df_input is None:
-            return
-
-        y = pred_single_db(model, df_input)
-        self.insert_forecast(
-            y, gauge_config["gauge"], gauge_config["model_folder"].name, end_time, member, created
-        )
+            y = pred_single_db(model, df_input)
+        finally:
+            self.insert_forecast(
+                y, gauge_config["gauge"], gauge_config["model_folder"].name, end_time, member, created
+            )
 
     def load_input_db(self, in_size, end_time, gauge_config,created) -> pd.DataFrame:
         """
@@ -229,16 +253,6 @@ class OracleWaVoConnection:
                 df_main.columns,extra={"gauge":gauge_config["gauge"]})
             logging.error(e.args[0])
             df_input = None
-        #except LookupError as e:
-        #    if df_input is not None:
-        #        logging.error(
-        #            "No data for the chosen timeperiod up to %s and columns in the database.",
-        #            end_time,
-        #            extra={"gauge":gauge_config["gauge"]}
-        #        )
-        #        df_input = None
-        #    logging.error(e.args[0])
-
 
         if df_input is None:
             logging.error("Input sensordata could not be loaded, inserting empty forecasts",extra={"gauge":gauge_config["gauge"]})
@@ -260,57 +274,6 @@ class OracleWaVoConnection:
 
         return df_input
 
-    def merge_synth_fcst(
-        self,
-        df_base_input: pd.DataFrame,
-        member: int,
-        end_time: pd.Timestamp,
-        gauge_config: dict,
-        created: pd.Timestamp = None,
-    ) -> pd.DataFrame:
-        """
-        Merges external forecasts into the base input dataframe by replacing the old wrong values.
-
-        Args:
-            df_base_input (pd.DataFrame): The base input dataframe.
-            member (int): The member identifier.
-            end_time (pd.Timestamp): The timestamp of the forecast.
-            gauge_config (dict): Settings for the gauge, especially the external forecasts and the gauge name.
-            created (pd.Timestamp): The timestamp of the start of the forecast creation.
-
-        Returns:
-            pd.DataFrame: The merged dataframe with external forecasts.
-        """
-        #TODO: besprechen ob wir einfach alle spalten durch vorhersagen ergänzen falls vorhanden?
-        df_input = df_base_input.copy()
-
-        for col in gauge_config["external_fcst"]:
-            try:
-                ext_fcst = self.get_ext_forecast(col, member, end_time)
-            except AttributeError:
-                logging.error(
-                    "External forecast %s time %s ensemble %s is missing",
-                    col,
-                    end_time,
-                    member,
-                    extra={"gauge":gauge_config["gauge"]}
-                )
-                # if some external forecasts are missing, insert NaNs and skip prediction
-                y = torch.Tensor(48).fill_(np.nan)
-                self.insert_forecast(y, gauge_config["gauge"], gauge_config["model_folder"].name, end_time, member, created)
-                return None
-            else:
-                # Replace missing values with forecasts.
-                # It is guaranteed that the only missing values are in the "future" (after end_time)
-                assert df_input[:-48].isna().sum().sum() == 0
-                nan_indices = np.where(df_input[col].isna())[0] - (144 - 48)
-                nan_indices = nan_indices[nan_indices >= 0]
-                nan_indices2 = nan_indices + (144 - 48)
-                df_input.iloc[nan_indices2, df_input.columns.get_loc(col)] = np.array(
-                    ext_fcst
-                )[nan_indices]
-
-        return df_input
 
     def add_zrxp_data(self, zrxp_folders: List[Path]) -> None:
         """Adds zrxp data to the database
@@ -332,29 +295,18 @@ class OracleWaVoConnection:
                     zrxp_file, skiprows=3, header=None, sep=" ", parse_dates=[0]
                 )
 
-                # zrxp_time = df_zrxp.iloc[0, 0]
-                # member = int(df_zrxp.iloc[0, 2])
-                # try:
-                #    # self.insert_external_forecast(zrxp_time, fcst_name, df_zrxp[3], member)
-                #    self.insert_external_forecast(
-                #        zrxp_time, vhs_gebiet, df_zrxp[3], member
-                #    )
                 try:
                     if len(df_zrxp.columns) == 4:
                         zrxp_time = df_zrxp.iloc[0, 0]
                         member = int(df_zrxp.iloc[0, 2])
-                        self.insert_external_forecast(
-                            zrxp_time, vhs_gebiet, df_zrxp[3], member
-                        )
+                        forecast = df_zrxp[3]
+
                     else:
                         zrxp_time = df_zrxp.iloc[0, 0] - pd.Timedelta(hours=1)
                         member = 0
-                        # self.insert_external_forecast(zrxp_time, fcst_name, df_zrxp[3], member)
-                        self.insert_external_forecast(
-                            zrxp_time, vhs_gebiet, df_zrxp[1], member
-                        )
+                        forecast = df_zrxp[1]
+                    self.insert_external_forecast(zrxp_time, vhs_gebiet, forecast, member)
                 except oracledb.IntegrityError as e:
-
                     if e.args[0].code == 2291:
                         logging.warning(
                             "%s Does the sensor_name %s exist in MODELL_SENSOR.VHS_GEBIET?",
@@ -381,10 +333,6 @@ class OracleWaVoConnection:
         Returns:
             List(int): The external forecast data.
         """
-        # TODO check new column in MODELL_SENSOR
-
-        #TODO this would be more efficient outside the member loop
-        #TODO error handling
         stmt = select(InputForecastsMeta).where(InputForecastsMeta.sensor_name == (bindparam("sensor_name")))
         params = {"sensor_name" : sensor_name}
         with Session(self.engine) as session:
@@ -393,10 +341,6 @@ class OracleWaVoConnection:
         if input_meta is None:
             raise AttributeError(f"Forecast {sensor_name} not found in Table InputForecastsMeta")
         
-        #sensor = self.get_sensor_name(vhs_gebiet)  # maybe check for none?
-        #if sensor is None:
-        #    raise AttributeError(f"Sensor {vhs_gebiet} not found in database")
-        
         stmt2 = select(InputForecasts).where(
             InputForecasts.tstamp == bindparam("tstamp"),
             InputForecasts.sensor_name == bindparam("sensor_name"),
@@ -429,7 +373,6 @@ class OracleWaVoConnection:
         Returns:
             None
         """
-
         sensor = self.get_sensor_name(vhs_gebiet)
         if sensor is None:
             return
@@ -558,7 +501,6 @@ class OracleWaVoConnection:
         print(
             "THIS METHOD IS DEPRECATED, the LFU created tables, not sure how to map their constraints to the oracle_db library"
         )
-        return
 
         with self.con.cursor() as cursor:
             table_list = [
@@ -671,7 +613,7 @@ class OracleWaVoConnection:
         return [pd.to_datetime(self.main_config["start"])]
 
     def get_sensor_name(self, vhs_gebiet: str) -> Sensor:
-        """Returns the sensor name based on the vhs_gebiet, by looking it up in the table ModellSensor and checking for uniqueness of combination and existence.
+        """Returns the sensor name based on the vhs_gebiet, by looking it up in the table INPUT_FORECASTS_META
 
         Args:
             vhs_gebiet (str): The vhs_gebiet name.
@@ -680,28 +622,12 @@ class OracleWaVoConnection:
             Sensor: A Sensor object (a row from the table Sensor).
         """
 
+        stmt = select(InputForecastsMeta).where(InputForecastsMeta.vhs_gebiet == bindparam("vhs_gebiet"))
+        params = {"vhs_gebiet" : vhs_gebiet}
+        with Session(self.engine) as session:
+            input_meta = session.scalar(statement=stmt, params=params)
 
-        with Session(bind=self.engine) as session:
-            # get all sensor names for the vhs_gebiet from the model_sensor table and check if there is only one.
-            stmt = select(ModellSensor.sensor_name).where(
-                ModellSensor.vhs_gebiet == bindparam("vhs_gebiet")
-            )
-            sensor_names = set(
-                session.scalars(stmt, params={"vhs_gebiet": vhs_gebiet}).all()
-            )
-            if len(sensor_names) == 0:
-                logging.warning("No sensor_name found for %s", vhs_gebiet)
-                self.not_found.append(vhs_gebiet)
-
-                return
-            elif len(sensor_names) > 1:
-                logging.warning("Multiple sensor_names found for %s", vhs_gebiet)
-                return
-            sensor_name = list(sensor_names)[0]
-            sensor = session.scalar(
-                select(Sensor).where(Sensor.sensor_name == sensor_name)
-            )
-        return sensor
+        return input_meta
 
     # @deprecated("This method is deprecated, data can be loaded directly from wiski (Not part of this package).")
     def add_sensor_data(self, sensor_folder, force=False) -> None:
@@ -804,7 +730,7 @@ def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor:
         warnings.filterwarnings(action="ignore", category=UserWarning)
         if model.differencing == 1:
             df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0)
-
+        #TODO time embedding stuff
         x = model.scaler.transform(
             df_input.values
         )  # TODO values/df column names wrong/old
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index 7848312a5a75395c3d0c6011a107c0354a91b75b..950f0424a9b789eacb3d491a014d712374c994a8 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -121,16 +121,25 @@ def load_model(model_dir: Path) -> WaVoLightningModule:
                 optimizer="adam",
             )
         except TypeError:
-            model = WaVoLightningModule.load_from_checkpoint(
-                checkpoint_path=next((model_dir / "checkpoints").iterdir()),
-                map_location=map_location,
-                scaler=config["scaler"],
-                optimizer="adam",
-                differencing=0,
-                gauge_idx=-1,
-                embed_time=False
-            )
-
+            try:
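+                # Older checkpoints may not store all current hyperparameters; retry the load while supplying
+                # defaults explicitly (and, in the final fallback below, differencing=0 as well).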
+                model = WaVoLightningModule.load_from_checkpoint(
+                    checkpoint_path=next((model_dir / "checkpoints").iterdir()),
+                    map_location=map_location,
+                    scaler=config["scaler"],
+                    optimizer="adam",
+                    gauge_idx=-1,
+                    embed_time=False
+                )
+            except TypeError:
+                model = WaVoLightningModule.load_from_checkpoint(
+                    checkpoint_path=next((model_dir / "checkpoints").iterdir()),
+                    map_location=map_location,
+                    scaler=config["scaler"],
+                    optimizer="adam",
+                    differencing=0,
+                    gauge_idx=-1,
+                    embed_time=False
+                )
     return model