diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..6e355b304044ca964ef1ccd853d20cf5dafd4e2a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "notebooks"] + path = notebooks + url = git@git.informatik.uni-kiel.de:ag-ins/projects/kiwavo/notebooks.git diff --git a/configs/bsh_export.yaml b/configs/bsh_export.yaml new file mode 100644 index 0000000000000000000000000000000000000000..015361224fb5a7a42d90253363bbd9bc035df231 --- /dev/null +++ b/configs/bsh_export.yaml @@ -0,0 +1,8 @@ +source_file: ../data/MOS.zip +target_folder: ../data/bsh/ +target_folder_wiski: ../data/4WISKI/ +db_params: + user: c##mspils + #The password is for a local database with no connection to the internet. It is fine for it to be public. + password: cobalt_deviancy + dsn: localhost/XE \ No newline at end of file diff --git a/configs/db_import_model.yaml b/configs/db_import_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aad2e3c6dcc5a9cfabffd5bccee6eac5ac69ebef --- /dev/null +++ b/configs/db_import_model.yaml @@ -0,0 +1,27 @@ +db_params: + user: c##mspils + #The password is for a local database with no connection to the internet. It is fine for it to be public. + password: cobalt_deviancy + dsn: localhost/XE +models: + - gauge: mstr2,parm_sh,ts-sh + model_name : mein_model_unfug + model_folder: ../models_torch/tarp_version_49/ + columns: + - a,b,c + - mstr2,parm_sh,ts-sh + external_fcst: + - a,b,c + - gauge: 114547,S,60m.Cmd + model_name : mein_model + model_folder: ../models_torch/tarp_version_49/ + columns: + - 4466,SHum,vwsl + - 4466,SHum,bfwls + - 4466,AT,h.Cmd + - 114435,S,60m.Cmd + - 114050,S,60m.Cmd + - 114547,S,60m.Cmd + - 114547,Precip,h.Cmd + external_fcst: + - 114547,Precip,h.Cmd \ No newline at end of file diff --git a/configs/db_new.yaml b/configs/db_new.yaml index 8b4c9a08b6c988f47fd0a49addd9e99579c4151d..dd07f1acb2c833fe4b7d033e2a46459c8d13e822 100644 --- a/configs/db_new.yaml +++ b/configs/db_new.yaml @@ -2,30 +2,35 @@ main_config: max_missing : 120 db_params: user: c##mspils + #The password is for a local database with no connection to the internet. It is fine for it to be public.
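+    # Rough consumption sketch (see src/bsh_extractor_v2.py and src/app.py elsewhere in this changeset);
+    # the db_params block is handed straight to python-oracledb:
+    #   con = oracledb.connect(**config["db_params"])
+    #   engine = create_engine("oracle+oracledb://", creator=lambda: con)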
password: cobalt_deviancy dsn: localhost/XE # Can also be given as a list zrxp_folder : ../data/dwd_ens_zrpx/ sensor_folder : ../data/db_in/ zrxp_out_folder : ../data/zrxp_out/ + #If True exports the forecasts if successful to zrxp + export_zrxp : True #if True tries to put all files in folder in the external forecast table + fake_external : False load_sensor : False - load_zrxp : False + load_zrxp : True ensemble : True - single : False - dummy : True + single : True + dummy : False #start: 2019-12-01 02:00 #start: 2023-04-19 03:00 #start: 2023-11-19 03:00 start: 2024-09-13 09:00 + end: !!str 2024-09-14 09:00 #end: !!str 2021-02-06 # if range is true will try all values from start to end for predictions - #range: !!bool True - range: !!bool False + range: !!bool True + #range: !!bool False zrxp: !!bool True gauge_configs: - - gauge: hollingstedt - model_folder: ../models_torch/hollingstedt_version_14 + - gauge: 112211,S,5m.Cmd + model_folder: ../models_torch/hollingstedt_version_37 columns: - 4466,SHum,vwsl - 4466,SHum,bfwls diff --git a/configs/db_test_nocuda.yaml b/configs/db_test_nocuda.yaml index 46f14b33fd6c54ef086ae7d0b7940953da57dd3f..3301f97ba742c7265099f2c8a0743e181e2e2c37 100644 --- a/configs/db_test_nocuda.yaml +++ b/configs/db_test_nocuda.yaml @@ -10,20 +10,34 @@ sensor_folder : ../data/db_in/ zrxp_out_folder : ../data/zrxp_out/ #if True tries to put all files in folder in the external forecast table load_sensor : False -load_zrxp : True +load_zrxp : False ensemble : True single : False dummy : True #start: 2019-12-01 02:00 #start: 2023-04-19 03:00 #start: 2023-11-19 03:00 -start: 2024-05-02 09:00 +start: 2024-09-13 09:00 #end: !!str 2021-02-06 # if range is true will try all values from start to end for predictions #range: !!bool True range: !!bool False zrxp: !!bool True --- +document: 114547,S,60m.Cmd +model_folder: +- ../models_torch/tarp_version_49/ +columns: + - 4466,SHum,vwsl + - 4466,SHum,bfwls + - 4466,AT,h.Cmd + - 114435,S,60m.Cmd + - 114050,S,60m.Cmd + - 114547,S,60m.Cmd + - 114547,Precip,h.Cmd +external_fcst: + 114547,Precip,h.Cmd : Tarp +--- document: 114069,Precip,h.Cmd model_folder: - ../models_torch/version_74_treia/ diff --git a/configs/icon_exp_ensemble.yaml b/configs/icon_exp_ensemble.yaml index 27338164b0b4cc002aa2f22c9c15bec9c040bb38..d32e16129f695b4f56c1a5888046263f10b9ff0f 100644 --- a/configs/icon_exp_ensemble.yaml +++ b/configs/icon_exp_ensemble.yaml @@ -6,4 +6,9 @@ mode: ensemble force_index_calculation: False log_file: ../data/ICON_ENSEMBLE_EXPORT/icon_export.log tidy_up: True -debug: False \ No newline at end of file +debug: False +db_params: + user: c##mspils + #The password is for a local database with no connection to the internet. It is fine for it to be public.
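+  # Optional sketch, roughly as handled in src/icon_export_v3.py in this changeset: db_params may
+  # also carry a config_dir entry to switch python-oracledb into thick mode via
+  # oracledb.init_oracle_client(lib_dir=config_dir). Example below is a placeholder (path is hypothetical):
+  #   db_params:
+  #     user: c##mspils
+  #     password: cobalt_deviancy
+  #     dsn: localhost/XE
+  #     config_dir: /opt/oracle/instantclient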
+ password: cobalt_deviancy + dsn: localhost/XE \ No newline at end of file diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000000000000000000000000000000000000..cef4d6e2754b97b2af1c6c672400d912e4293c51 --- /dev/null +++ b/src/app.py @@ -0,0 +1,706 @@ +import os +from datetime import datetime, timedelta,timezone + +import dash_bootstrap_components as dbc +import oracledb +import pandas as pd +from dash import Dash, dash_table, dcc, html +from dash.dependencies import MATCH, Input, Output, State +from sqlalchemy import and_, create_engine, desc, func, select,distinct +from sqlalchemy.orm import Session + +from dashboard.styles import TABLE_STYLE, TAB_STYLE +#from dash_tools.layout_helper import create_collapsible_section +from dashboard.tables_and_plots import (create_historical_plot, + create_historical_table, + create_inp_forecast_status_table, + create_log_table, create_metrics_plots, + create_model_forecasts_plot, + create_status_summary_chart) + +from dashboard.constants import NUM_RECENT_FORECASTS, NUM_RECENT_LOGS, BUFFER_TIME,LOOKBACK_EXTERNAL_FORECAST_EXISTENCE +from utils.db_tools.orm_classes import (InputForecasts, Log, Metric, Modell, + ModellSensor, PegelForecasts, + SensorData) + + + +#TODO add logging to dashboard? +class ForecastMonitor: + def __init__(self, username, password, dsn): + self.db_params = { + "user": username, + "password": password, + "dsn": dsn + } + try: + self.con = oracledb.connect(**self.db_params) + self.engine = create_engine("oracle+oracledb://", creator=lambda: self.con) + except oracledb.OperationalError as e: + print(f"Error connecting to database: {str(e)}") + self.con = None + self.engine = None + + self.app = Dash(__name__, + suppress_callback_exceptions=True, + external_stylesheets=[dbc.themes.BOOTSTRAP], + assets_url_path='dashboard') + self.setup_layout() + self.setup_callbacks() + + + def _make_section(self,title,section_id,view_id): + # Input Forecasts Section + section_card = dbc.Card([ + dbc.CardHeader([ + dbc.Button( + children=[ + html.Span(title, style={'flex': '1'}), + html.Span("▼", className="ms-2") + ], + color="link", + id={'type': 'collapse-button', 'section': section_id}, + className="text-decoration-none text-dark h5 mb-0 w-100 d-flex align-items-center", + ) + ]), + dbc.Collapse( + dbc.CardBody(html.Div(id=view_id)), + id={'type': 'collapse-content', 'section': section_id}, + is_open=True + ) + ], className="mb-4") + return section_card + + + def setup_layout(self): + # Header + header = html.Div([ + # Main header bar + dbc.Navbar( + [ + dbc.Container([ + dbc.Row([ + dbc.Col(html.H1("Pegel Dashboard", + className="mb-0 text-white d-flex align-items-center"#className="mb-0 text-white" + ), + width="auto"), + dbc.Col(dbc.Tabs([ + dbc.Tab(label="Modell Monitoring",tab_id="model-monitoring",**TAB_STYLE), + dbc.Tab(label="System Logs",tab_id="system-logs",**TAB_STYLE) + ], + id="main-tabs", + active_tab="model-monitoring", + className="nav-tabs border-0", + ), + width="auto", + className="d-flex align-items-end" + ) + ]), + ], + fluid=True, + className="px-4"), + ], + color="primary", + dark=True + )], + className="mb-4") + + + # Main content + self.model_monitoring_tab = dbc.Container([ + # Row for the entire dashboard content + dbc.Row([ + # Left column - fixed width + dbc.Col([ + dbc.Card([ + dbc.CardHeader(html.H3("Modell Status", className="h5 mb-0")), + dbc.CardBody([ + html.Div(id='model-status-table'), + dcc.Interval( + id='status-update', + interval=600000 # Update every 10 minutes + ) + ]) + ], 
className="mb-4"), + dbc.Card([ + dbc.CardHeader(html.H3("Status Übersicht", className="h5 mb-0")), + dbc.CardBody([dcc.Graph(id='status-summary-chart')]) + ]) + ], xs=12, md=5, className="mb-4 mb-md-0"), # Responsive column + # Right column + dbc.Col([ + dcc.Store(id='current-sensor-names'), + # Model Forecasts Section + self._make_section("Vorhersagen",'forecasts','fcst-view'), + # Historical Data Section + self._make_section("Messwerte",'historical','historical-view'), + # Input Forecasts Section + #self._make_section("Input Forecasts",'input-forecasts','inp-fcst-view'), + # Recent Logs Section + self._make_section("Logs",'logs','log-view'), + # Metrics Section + self._make_section("Metriken", 'metrics', 'metrics-view'), + ], xs=12, md=7) # Responsive column + ]) + ], fluid=True, className="py-4") + + # System Logs Tab Content + self.system_logs_tab = dbc.Container([ + dbc.Card([ + dbc.CardBody([ + html.H3("System Logs", className="mb-4"), + dcc.Interval( #TODO rausnehmen? + id='logs-update', + interval=30000 # Update every 30 seconds + ), + dash_table.DataTable( + id='system-log-table', + columns=[ + {'name': 'Zeitstempel', 'id': 'timestamp'}, + {'name': 'Level', 'id': 'level'}, + {'name': 'Pegel', 'id': 'sensor'}, + {'name': 'Nachricht', 'id': 'message'}, + {'name': 'Modul', 'id': 'module'}, + {'name': 'Funktion', 'id': 'function'}, + {'name': 'Zeile', 'id': 'line'}, + {'name': 'Exception', 'id': 'exception'} + ], + **TABLE_STYLE, + page_size=20, + page_action='native', + sort_action='native', + sort_mode='multi', + filter_action='native', + ) + ]) + ]) + ], fluid=True, className="py-4") + + # Add a container to hold the tab content + tab_content = html.Div( + id="tab-content", + children=[ + dbc.Container( + id="page-content", + children=[], + fluid=True, + className="py-4" + ) + ] + ) + # Main content with tabs + #main_content = dbc.Tabs([ + # dbc.Tab(model_monitoring_tab, label="Modell Monitoring", tab_id="model-monitoring"), + # dbc.Tab(system_logs_tab, label="System Logs", tab_id="system-logs") + #], id="main-tabs", active_tab="model-monitoring") + + # Combine all elements + self.app.layout = html.Div([ + header, + html.Div(id="tab-content", children=[ + self.model_monitoring_tab, + self.system_logs_tab + ], style={"display": "none"}), # Hidden container for both tabs + html.Div(id="visible-tab-content") # Container for the visible tab + ], className="min-vh-100 bg-light") + + # Combine all elements + #self.app.layout = html.Div([ + # header, + # main_content + #], className="min-vh-100 bg-light") + + def get_model_name_from_path(self, modelldatei): + if not modelldatei: + return None + return os.path.basename(modelldatei) + + def get_active_models_status(self): + try: + now = datetime.now(timezone.utc) + last_required_hour = now -BUFFER_TIME + last_required_hour = last_required_hour.replace(minute=0, second=0, microsecond=0) + while last_required_hour.hour % 3 != 0: + last_required_hour -= timedelta(hours=1) + + with Session(self.engine) as session: + active_models = session.query(Modell).filter(Modell.aktiv == 1).all() + + model_status = [] + for model in active_models: + actual_model_name = self.get_model_name_from_path(model.modelldatei) + if not actual_model_name: + continue + + current_forecast = session.query(PegelForecasts).filter( + and_( + PegelForecasts.model == actual_model_name, + PegelForecasts.tstamp == last_required_hour + ) + ).first() + + last_valid_forecast = session.query(PegelForecasts).filter( + and_( + PegelForecasts.model == actual_model_name, + 
PegelForecasts.h1 is not None + ) + ).order_by(PegelForecasts.tstamp.desc()).first() + + model_status.append({ + 'model_name': model.modellname, + 'actual_model_name': actual_model_name, + 'model_id': model.id, # Added for historical data lookup + 'sensor_name': last_valid_forecast.sensor_name if last_valid_forecast else None, + 'has_current_forecast': current_forecast is not None, + 'last_forecast_time': last_valid_forecast.tstamp if last_valid_forecast else None, + 'forecast_created': last_valid_forecast.created if last_valid_forecast else None, + }) + + return { + 'model_status': model_status, + 'last_check_time': now, + 'required_timestamp': last_required_hour + } + + except Exception as e: + print(f"Error getting model status: {str(e)}") + return None + + def get_recent_forecasts(self, actual_model_name): + """Get recent forecasts for a specific model and sensor""" + try: + subq = ( + select( + PegelForecasts, + func.row_number() + .over( + partition_by=[PegelForecasts.member], + order_by=PegelForecasts.tstamp.desc() + ).label('rn')) + .where(PegelForecasts.model == actual_model_name) + .subquery() + ) + + stmt = ( + select(subq) + .where(subq.c.rn <= NUM_RECENT_FORECASTS) + .order_by( + subq.c.member, + subq.c.tstamp.desc() + ) + ) + df = pd.read_sql(sql=stmt, con=self.engine) + df.drop(columns=['rn'], inplace=True) + df.dropna(inplace=True) + return df + + + except Exception as e: + print(f"Error getting recent forecasts: {str(e)}") + return pd.DataFrame() + + def get_input_forecasts(self, sensor_names): + """Get input forecasts for the last NUM_RECENT_FORECASTS+1 *3 hours for the given sensor names""" + try: + now = datetime.now(timezone.utc) + stmt = select(InputForecasts).where(InputForecasts.sensor_name.in_(sensor_names),InputForecasts.tstamp > (now - (NUM_RECENT_FORECASTS+1)*pd.Timedelta(hours=3))) + df = pd.read_sql(sql=stmt, con=self.engine) + + with Session(self.engine) as session: + existing = session.query(distinct(InputForecasts.sensor_name))\ + .filter(InputForecasts.sensor_name.in_(sensor_names))\ + .filter(InputForecasts.tstamp >= now - timedelta(days=LOOKBACK_EXTERNAL_FORECAST_EXISTENCE))\ + .all() + ext_forecast_names = [x[0] for x in existing] + + + #df.drop(columns=['rn'], inplace=True) + return df, ext_forecast_names + except Exception as e: + raise RuntimeError(f"Error getting input forecasts: {str(e)}") from e + + def get_recent_logs(self, sensor_name): + """Get the NUM_RECENT_LOGS most recent logs for a specific sensor""" + if sensor_name is None: + return [] + try: + with Session(self.engine) as session: + logs = session.query(Log).filter( + Log.gauge == sensor_name + ).order_by( + desc(Log.created) + ).limit(NUM_RECENT_LOGS).all() + + return [ + { + 'timestamp': log.created.strftime('%Y-%m-%d %H:%M:%S'), + 'level': log.loglevelname, + 'sensor': log.gauge, + 'message': log.message, + 'module': log.module, + 'function': log.funcname, + 'line': log.lineno, + 'exception': log.exception or '' + } + for log in logs + ] + except Exception as e: + print(f"Error getting logs: {str(e)}") + return [] + + def get_historical_data(self, model_id): + """Get last 144 hours of sensor data for all sensors associated with the model""" + try: + with Session(self.engine) as session: + # Get all sensors for this model + model_sensors = session.query(ModellSensor).filter( + ModellSensor.modell_id == model_id + ).all() + + sensor_names = [ms.sensor_name for ms in model_sensors] + + # Get last 144 hours of sensor data + time_threshold = datetime.now(timezone.utc) - 
timedelta(hours=144) - BUFFER_TIME + time_threshold= pd.to_datetime("2024-09-13 14:00:00.000") - timedelta(hours=144) #TODO rausnehmen + + stmt = select(SensorData).where( + SensorData.tstamp >= time_threshold, + SensorData.sensor_name.in_(sensor_names)) + + df = pd.read_sql(sql=stmt,con=self.engine,index_col="tstamp") + df = df.pivot(columns="sensor_name", values="sensor_value")[sensor_names] + + return df + + except Exception as e: + print(f"Error getting historical data: {str(e)}") + return pd.DataFrame() + + def get_model_metrics(self, model_id, metric_name='mae'): + """Get metrics for a specific model. Assumes that each model only has on set of metrics with the same name""" + + try: + + metric_names = [f"train_{metric_name}", f"test_{metric_name}", f"val_{metric_name}", + f"train_{metric_name}_flood", f"test_{metric_name}_flood", f"val_{metric_name}_flood"] + # Get metrics for the model + stmt = select(Metric).where(Metric.model_id == model_id,Metric.metric_name.in_(metric_names)) + + df = pd.read_sql(sql=stmt, con=self.engine) + + train_start,train_end = df[df["metric_name"] == f"train_{metric_name}"][["start_time","end_time"]].iloc[0] + test_start,test_end = df[df["metric_name"] == f"test_{metric_name}"][["start_time","end_time"]].iloc[0] + times = [train_start,train_end,test_start,test_end] + + # Group metrics by tag and timestep + df = df.pivot(index="timestep",columns="metric_name",values="value") + + return df, times + except Exception as e: + print(f"Error getting model metrics: {str(e)}") + return None, None + + + + + def setup_callbacks(self): + # Add this new callback at the start of setup_callbacks + @self.app.callback( + Output("visible-tab-content", "children"), + Input("main-tabs", "active_tab") + ) + def switch_tab(active_tab): + if active_tab == "model-monitoring": + return self.model_monitoring_tab + elif active_tab == "system-logs": + return self.system_logs_tab + return html.P("No tab selected") + + @self.app.callback( + [Output('model-status-table', 'children'), + Output('status-summary-chart', 'figure')], + Input('status-update', 'n_intervals') + ) + def update_dashboard(n): + if self.con is None: + try: + self.con = oracledb.connect(**self.db_params) + self.engine = create_engine("oracle+oracledb://", creator=lambda: self.con) + except oracledb.OperationalError: + return dbc.Alert("Error connecting to database", color="danger"), {} + + status = self.get_active_models_status() + if not status: + #return html.Div("Error fetching data"), go.Figure() + return dbc.Alert("Error fetching data", color="danger"), {} + + header = html.Div([ + html.H4(f"Status am {status['last_check_time'].strftime('%Y-%m-%d %H:%M:%S')}",className="h6 mb-2"), + html.P(f"Vorhersagen erwartet für: {status['required_timestamp'].strftime('%Y-%m-%d %H:%M:00')}",className="text-muted small") + ]) + + table = dash_table.DataTable( + id='status-table', + columns=[ + {'name': 'Modell', 'id': 'model_name'}, + {'name': 'Status', 'id': 'has_current_forecast'}, + {'name': 'Neueste Vorhersage', 'id': 'last_forecast_time'}, + {'name': 'Vorhersage erstellt', 'id': 'forecast_created'}, + {'name': 'Pegel', 'id': 'sensor_name'}, + {'name': 'model_id', 'id': 'model_id', 'hideable': True}, + {'name': 'actual_model_name', 'id': 'actual_model_name', 'hideable': True} + ], + data=[{ + 'model_name': row['model_name'], + 'has_current_forecast': '✓' if row['has_current_forecast'] else '✗', + 'last_forecast_time': row['last_forecast_time'].strftime('%Y-%m-%d %H:%M:%S') if row['last_forecast_time'] else 'No valid 
forecast', + 'forecast_created': row['forecast_created'].strftime('%Y-%m-%d %H:%M:%S') if row['forecast_created'] else 'N/A', + 'sensor_name': row['sensor_name'], + 'model_id': row['model_id'], + 'actual_model_name': row['actual_model_name'] + } for row in status['model_status']], + hidden_columns=['model_id', 'actual_model_name'], # Specify hidden columns here + **TABLE_STYLE, + row_selectable='single', + selected_rows=[], + ) + + fig = create_status_summary_chart(status) + + return html.Div([header, table]), fig + + + + @self.app.callback( + [Output('fcst-view', 'children'), + Output('historical-view', 'children'), + Output('log-view', 'children'), + Output('metrics-view', 'children'), + Output('current-sensor-names', 'data')], # Removed input-forecasts-view + [Input('status-table', 'selected_rows')], + [State('status-table', 'data')] + ) + def update_right_column(selected_rows, table_data): + if not selected_rows: + return (html.Div("Wähle ein Modell um Vorhersagen anzuzeigen."), + html.Div("Wähle ein Modell um Messwerte anzuzeigen."), + html.Div("Wähle ein Modell um Logs anzuzeigen."), + html.Div("Wähle ein Modell um Metriken anzuzeigen."), + None) # Removed input forecasts return + + selected_row = table_data[selected_rows[0]] + sensor_name = selected_row['sensor_name'] + model_id = selected_row['model_id'] + model_name = selected_row['model_name'] + actual_model_name = selected_row['actual_model_name'] + + # Get logs + logs = self.get_recent_logs(sensor_name) + log_table = create_log_table(logs) + log_view = html.Div([ + html.H4(f"Neueste Logs für {model_name}"), + log_table + ]) + + # Get historical data + df_historical = self.get_historical_data(model_id) + df_filtered = df_historical[df_historical.isna().any(axis=1)] + + sensor_names = list(df_historical.columns) + + if not df_historical.empty: + fig = create_historical_plot(df_historical, model_name) + historical_table = create_historical_table(df_filtered) + + + historical_view = html.Div([ + #html.H4(f"Messdaten {model_name}"), + dcc.Graph(figure=fig), + html.H4("Zeitpunkte mit fehlenden Messdaten", style={'marginTop': '20px', 'marginBottom': '10px'}), + html.Div(historical_table, style={'width': '100%', 'padding': '10px'}) + ]) + else: + historical_view = html.Div("Keine Messdaten verfügbar") + + + # Get forecast data + df_inp_fcst,ext_forecast_names = self.get_input_forecasts(sensor_names) + inp_fcst_table = create_inp_forecast_status_table(df_inp_fcst,ext_forecast_names) + df_forecasts = self.get_recent_forecasts(actual_model_name) + if not df_forecasts.empty: + fig_fcst = create_model_forecasts_plot(df_forecasts,df_historical,df_inp_fcst) + fcst_view = html.Div([ + html.H4(f"Pegelvorhersage Modell {model_name}"), + dcc.Graph(figure=fig_fcst), + html.H4("Status Eingangsvorhersagen", style={'marginTop': '20px', 'marginBottom': '10px'}), + html.Div(inp_fcst_table, style={'width': '100%', 'padding': '10px'}) + ]) + else: + fcst_view = html.Div("No forecasts available") + + # Get metrics + metrics_view = html.Div([ + html.H4(f"Modell Metriken für {model_name}"), + dbc.Row([ + dbc.Col([ + html.Label("Metric:", className="fw-bold mb-2"), + dcc.Dropdown( + id={'type': 'metric-selector', 'section': 'metrics'}, + options=[ + {'label': metric.upper(), 'value': metric} + for metric in ['mae', 'mse', 'nse', 'kge', 'p10', 'p20', 'r2', 'rmse', 'wape'] + ], + value='mae', + clearable=False, + className="mb-4" + ) + ], width=4) + ]), + html.Div(id='metric-plots-container'), + html.Div(id='metric-info-container', className="mt-4") + ]) + + 
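+ # Note: the tuple returned below must keep the same order as this callback's Outputs: fcst-view, historical-view, log-view, metrics-view, current-sensor-names (dcc.Store data).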
return (fcst_view, + historical_view, + log_view, + metrics_view, + sensor_names) + + + @self.app.callback( + Output('historical-table', 'data'), + [Input('toggle-missing-values', 'n_clicks'), + Input('status-table', 'selected_rows')], + [State('status-table', 'data'), + State('toggle-missing-values', 'children')] + ) + def update_historical_table(n_clicks, selected_rows, table_data, button_text): + #print("update_historical_table") + if not selected_rows: + return [] + + selected_row = table_data[selected_rows[0]] + model_id = selected_row['model_id'] + + df_historical = self.get_historical_data(model_id) + if df_historical.empty: + return [] + + df_table = df_historical.reset_index() + df_table['tstamp'] = df_table['tstamp'].dt.strftime('%Y-%m-%d %H:%M') + + # Show filtered data by default, show all data after first click + if n_clicks and n_clicks % 2 == 1: + df_filtered = df_table + else: + df_filtered = df_table[df_table.isna().any(axis=1)] + + return df_filtered.to_dict('records') + + + + # Add new callback for updating metric plots + @self.app.callback( + [Output('metric-plots-container', 'children'), + Output('metric-info-container', 'children')], + [Input({'type': 'metric-selector', 'section': 'metrics'}, 'value'), + Input('status-table', 'selected_rows')], + [State('status-table', 'data')] + ) + def update_metric_plots(metric_name, selected_rows, table_data): + if not selected_rows or not metric_name: + return html.Div(), html.Div() + + selected_row = table_data[selected_rows[0]] + model_id = selected_row['model_id'] + + metrics_data, edge_times = self.get_model_metrics(model_id, metric_name) + if metrics_data.empty: + return html.Div("Keine Metriken verfügbar"), html.Div() + + fig = create_metrics_plots(metrics_data, metric_name) + + plots = html.Div([dcc.Graph(figure=fig, className="mb-4")]) + + info = html.Div([ + html.H5("Zeitraum der Metrikberechnung:", className="h6 mb-3"), + dbc.Row([ + dbc.Col([ + html.Strong("Start Trainingszeitraum: "), + html.Span(edge_times[0].strftime('%Y-%m-%d %H:%M:%S')) + ], width=6), + dbc.Col([ + html.Strong("Start Validierungszeitraum: "), + html.Span(edge_times[1].strftime('%Y-%m-%d %H:%M:%S')) + ], width=6), + dbc.Col([ + html.Strong("Start Testzeitraum: "), + html.Span(edge_times[2].strftime('%Y-%m-%d %H:%M:%S')) + ], width=6), + dbc.Col([ + html.Strong("Ende Testzeitraum: "), + html.Span(edge_times[3].strftime('%Y-%m-%d %H:%M:%S')) + ], width=6) + ]) + ], className="bg-light p-3 rounded") + + return plots, info + + @self.app.callback( + Output('system-log-table', 'data'), + Input('logs-update', 'n_intervals') + ) + def update_system_logs(n): + """Get all logs from the database""" + try: + with Session(self.engine) as session: + # Get last 1000 logs + logs = session.query(Log).order_by( + desc(Log.created) + ).limit(1000).all() + + return [{ + 'timestamp': log.created.strftime('%Y-%m-%d %H:%M:%S'), + 'level': log.loglevelname, + 'sensor': log.gauge, + 'message': log.message, + 'module': log.module, + 'function': log.funcname, + 'line': log.lineno, + 'exception': log.exception or '' + } for log in logs] + + except Exception as e: + print(f"Error getting system logs: {str(e)}") + return [] + + + # Update the collapse callback + @self.app.callback( + [Output({'type': 'collapse-content', 'section': MATCH}, 'is_open'), + Output({'type': 'collapse-button', 'section': MATCH}, 'children')], + [Input({'type': 'collapse-button', 'section': MATCH}, 'n_clicks')], + [State({'type': 'collapse-content', 'section': MATCH}, 'is_open'), + State({'type': 
'collapse-button', 'section': MATCH}, 'children')] + ) + def toggle_collapse(n_clicks, is_open, current_children): + if n_clicks: + title = current_children[0]['props']['children'] + if is_open: + return False, [html.Span(title, style={'flex': '1'}), html.Span("►", className="ms-2")] + else: + return True, [html.Span(title, style={'flex': '1'}), html.Span("▼", className="ms-2")] + return is_open, current_children + + def run(self, host='0.0.0.0', port=8050, debug=True): + self.app.run(host=host, port=port, debug=debug) + +if __name__ == '__main__': + monitor = ForecastMonitor( + username="c##mspils", + password="cobalt_deviancy", + dsn="localhost/XE" + ) + #monitor.run(host="172.17.0.1", port=8050, debug=True) + #monitor.run(host="134.245.232.166", port=8050, debug=True) + monitor.run() \ No newline at end of file diff --git a/src/bsh_extractor_v2.py b/src/bsh_extractor_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..45335202cd2cdb456e70fe2d1f04fbff55fb11de --- /dev/null +++ b/src/bsh_extractor_v2.py @@ -0,0 +1,288 @@ +"""Simple script to extract the BSH forecast from the MOS.zip file and save it as .zrx files.""" +import argparse +import logging +from io import TextIOWrapper +from pathlib import Path +from typing import Tuple +from zipfile import ZipFile + +import numpy as np +import oracledb +import pandas as pd +import yaml + +from utils.db_tools.db_tools import OracleDBHandler + +logger = logging.getLogger("bsh_export") +logger.setLevel(logging.DEBUG) + +def setup_logging(con=None) -> None: + """Set up rotating log files + + Args: + con ([type], optional): Oracle connection. Defaults to None. + """ + ####### DB Logging + old_factory = logging.getLogRecordFactory() + def record_factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + record.gauge_id = kwargs.pop("gauge_id", None) + return record + logging.setLogRecordFactory(record_factory) + #### + log_formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S") + log_handler_stream = logging.StreamHandler() + log_handler_stream.setLevel(logging.INFO) + log_handler_stream.setFormatter(log_formatter) + logger.addHandler(log_handler_stream) + ####### DB Logging + if con is not None: + db_handler = OracleDBHandler(con) + db_handler.setFormatter(log_formatter) + db_handler.setLevel(logging.INFO) + logger.addHandler(db_handler) + + +def parse_args() -> argparse.Namespace: + """Parse all the arguments and provides some help in the command line""" + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description="Extracts predictions from ICON Ensembles into zrx files" + ) + + parser.add_argument( + "yaml_path", + metavar="yaml_path", + type=Path, + help="""The path to your config file. + Each files is expected to contain: + source_file: Path to the MOS.zip file + target_folder: Path to the folder where the zrx files should be saved + target_folder_wiski: Path to the folder where the zrx files for WISKI should be saved + db_params: Dictionary containing the connection parameters to the database + + Example: + source_file: ../data/MOS.zip + target_folder: ../data/bsh/ + target_folder_wiski: ../data/4WISKI/ + #db_params can alternatively contain a config_dir for thick mode + db_params: + user: __ + password: __ + dsn: __ + """, + ) + return parser.parse_args() + + +def get_config(passed_args: argparse.Namespace) -> dict: + """ + Retrieves the configurations from a YAML file. + + Args: + passed_args (argparse.Namespace): The command-line arguments. 
+ + Returns: dict: The configuration + + """ + + with open(passed_args.yaml_path, "r", encoding="utf-8") as file: + config = yaml.safe_load(file) + config["source_file"] = Path(config["source_file"]) + config["target_folder"] = Path(config["target_folder"]) + config["target_folder_wiski"] = Path(config["target_folder_wiski"]) + assert config["target_folder"].exists(), f"target_folder {config['target_folder']} does not exist" + assert config["target_folder_wiski"].exists(), f"target_folder_wiski {config['target_folder_wiski']} does not exist" + + return config + +def get_offset(hour:int) -> int: + """ Calculates how many values to skip from the input hour to get to an hour divisible by 3 + + Args: + hour (int): Hour part of the timestamp for the first predicted value + + Returns: + int: The offset + """ + if hour % 3 == 0: + return 1 + if hour % 3 == 1: + return 0 + if hour % 3 == 2: + return 2 + + +def get_dat_file_name(zip_file_name:Path) -> Tuple[str, str]: + """Finds the complete Path to the Eidersperrwerk AP data file and the timestamp of when the forecast was made. + + Args: + zip_file_name (Path): Path to the MOS.zip file + + Returns: + Tuple[str, str]: The complete path to the data file and the timestamp of the forecast + """ + with ZipFile(zip_file_name,'r')as my_zip: + with my_zip.open("MOS/Pegel-namen-nummern.txt") as pegelnames: + pegelnames_str = TextIOWrapper(pegelnames,encoding="cp1252") #ANSI + findNr4EiderSpAP(pegelnames_str) + + dat_file = next(filter(lambda x: x.startswith("MOS/pgl/pgl_664P_"),my_zip.namelist())) + tsVorh = dat_file[-16:-4] + return dat_file,tsVorh + +def findNr4EiderSpAP(pegelnames:TextIOWrapper)-> None: + """Überreste vom LFU Praktikum. + Prüft ob DE__664P noch die Nummer für Eidersperrwerk AP ist. Wenn nicht wird das Programm beendet. + + Args: + pegelnames (TextIOWrapper): TextIOWrapper für die Pegelnamen Datei + """ + # Überreste Praktikant + # suche Zeile: DE__664P 9530010 Eider-Sperrwerk, Außenpegel + for row in pegelnames.readlines(): + if 'DE__664P' in row: + if 'Eider-Sperrwerk, Außenpegel' in row: + print(" Name-Nummer korrekt ") + return + else: + print("ERROR: Die Nummer von Eidersperrwerk AP hat anscheindend gewechselt") + logger.error("Die Nummer von Eidersperrwerk AP hat anscheindend gewechselt und ist nicht mehr DE__664P 9530010") + quit() + logger.error("Die Nummer von Eidersperrwerk AP DE__664P 9530010 wurde nicht gefunden") + quit() + +def read_dat_file(zip_file:Path) -> pd.DataFrame: + """Opens and extracts the data from the MOS.zip file, and returns a DataFrame with the relevant data from dat_file. + tsVorh is used for the timestamp column of the forecast. 
+ + Args: + zip_file (Path): Path to the MOS.zip file + + Returns: + pd.DataFrame: _description_ + """ + with ZipFile(zip_file,'r') as my_zip: + with my_zip.open("MOS/Pegel-namen-nummern.txt") as pegelnames: + pegelnames_str = TextIOWrapper(pegelnames,encoding="cp1252") #ANSI + findNr4EiderSpAP(pegelnames_str) + + dat_file = next(filter(lambda x: x.startswith("MOS/pgl/pgl_664P_"),my_zip.namelist())) + #tsVorh = dat_file[-16:-4] + + with my_zip.open(dat_file) as csv_file: + csv_file_str = TextIOWrapper(csv_file) + + df = pd.read_fwf(filepath_or_buffer=csv_file_str, + colspecs=[(7,19),(23,26),(26,28),(40,43),(43,48),(48,53),(53,58),(58,63),(63,68),(68,73)], + header=None,skiprows=1,parse_dates=[0],date_format="%Y%m%d%H%M") + df.columns= ["timestamp","Vorhersagezeitpunkt","Differenzzeit","gztmn","stau","r1","r2","r3","r4","r5"] + df["member"] = 0 + stau_cols = ["stau","r1","r2","r3","r4","r5"] + #Take the first valid fallback value or np.nan if all are invalid + df['pegel'] = df['gztmn'] + df[stau_cols].apply(lambda row: next((x for x in row if x >= -9900),np.nan), axis=1) + df["forecast"]= df["timestamp"] + pd.to_timedelta(df["Vorhersagezeitpunkt"],unit="h") + pd.to_timedelta(df["Differenzzeit"],unit="m") + pd.to_timedelta(1,unit="h") + #df["timestamp"] = tsVorh #TODO @Ralf diese Zeile finde ich etwas fragwürdig. Ich vermute dass diese Zeit falsch und die IN der datei richtig ist. + df = df[["timestamp","forecast","member","pegel"]] + return df + +def write_zrxp(target_file:Path,df:pd.DataFrame,only_two=False): + """Writes the DataFrame to a .zrx file with the correct header. + Overwrites the file if it already exists. + + Args: + target_file (Path): Where the file should be saved + df (pd.DataFrame): DataFrame with zrx data + only_two (bool, optional): If True, only the cols for wiski are written. + """ + if target_file.exists(): + logger.info("Overwriting existing file %s", target_file) + target_file.unlink() + else: + logger.info("Saving file %s", target_file) + + with open(target_file , 'w', encoding='utf-8') as file: + file.write('#REXCHANGE9530010.W.BSH_VH|*|\n') + file.write('#RINVAL-777|*|\n') + if only_two: + file.write('#LAYOUT(timestamp,value)|*|\n') + df = df[["forecast","pegel"]] + else: + file.write('#LAYOUT(timestamp,forecast,member,value)|*|\n') + + df.to_csv(path_or_buf =target_file , + header = False, + index=False, + mode='a', + sep = ' ', + date_format = '%Y%m%d%H%M') + +def make_target_file_name(base_dir:Path,tsVorh) -> Path: + """Creates the target file name for the .zrx file from the base_dir and the timestamp of the forecast. + + Args: + base_dir (Path): Base directory + tsVorh (_type_): String representation of the timestamp of the making of the forecast. (At least we pretend it is) + + Returns: + Path: The complete path to the target file + """ + if isinstance(tsVorh,str): + return base_dir / (tsVorh + "_9530010_EidersperrwerkAP_Vh.zrx") + return base_dir / (tsVorh.strftime("%Y%m%d%H%M") + "_9530010_EidersperrwerkAP_Vh.zrx") + +def main(): + """Main function of the script. 
Reads the data from the MOS.zip file, extracts the relevant data and writes it to .zrx files.""" + + parsed_args = parse_args() # configs/bsh_export.yaml + config = get_config(parsed_args) + + if "db_params" in config: + try: + logger.info("Trying to connect to DB") + if "config_dir" in config["db_params"]: + logger.info( + "Initiating Thick mode with executable %s", + config["db_params"]["config_dir"], + ) + oracledb.init_oracle_client(lib_dir=config["db_params"]["config_dir"]) + else: + logger.info("Initiating Thin mode") + con = oracledb.connect(**config["db_params"]) + except oracledb.OperationalError as e: + logger.error("Could not connect to DB: %s", e) + con = None + else: + con = None + logger.info("No DB connection") + setup_logging(con) + + + #dat_file,tsVorh = get_dat_file_name(config["source_file"]) + #target_file = make_target_file_name(config["target_folder_wiski"],tsVorh) + + df = read_dat_file(config["source_file"]) + target_file = make_target_file_name(config["target_folder_wiski"],df["timestamp"][0]) + write_zrxp(target_file,df,only_two=True) + + #Resample and cut data to fit the 3h interval, hourly sample rate used by ICON + df2 = df[["forecast","pegel"]].resample("1h",on="forecast").first() + df2["forecast"] = df2.index + offset = get_offset(df2["forecast"].iloc[0].hour) + df2 = df2.iloc[offset:] + df2.loc[:,"timestamp"] = df2["forecast"].iloc[0] - pd.Timedelta("1h") + df2["member"] = 0 + df2 = df2[["timestamp","forecast","member","pegel"]] + + #Use the one prediction as forecasts for the next three possible predict_db runs + #If another bsh forecast is made before the next predict_db run, the forecast will be overwritten in the database. + for offset in [0,3,6]: + df3 = df2.copy() + df3["timestamp"] = df3["timestamp"] + pd.Timedelta(hours=offset) + target_file = make_target_file_name(config["target_folder"],df3["timestamp"].iloc[0]) + write_zrxp(target_file,df3[offset:offset+48]) + + + + +if __name__ == "__main__": + main() diff --git a/src/dashboard/constants.py b/src/dashboard/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..2aa1f61f34bae35ad4b461529ad869eaff95fef8 --- /dev/null +++ b/src/dashboard/constants.py @@ -0,0 +1,11 @@ +from datetime import timedelta + +NUM_RECENT_FORECASTS = 5 +NUM_RECENT_LOGS = 10 +BUFFER_TIME = timedelta(hours=2,minutes=30) +PERIODS_EXT_FORECAST_TABLE = 10 + + +# How many days to look back into the past to check whether an input value for the model has external forecasts. +# This would also work without the cutoff, but it would be much less efficient. If the external forecasts arrive reliably, this can be reduced to a few days.
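+# Usage sketch, mirroring get_input_forecasts() in src/app.py from this changeset:
+#   session.query(distinct(InputForecasts.sensor_name))
+#       .filter(InputForecasts.sensor_name.in_(sensor_names))
+#       .filter(InputForecasts.tstamp >= now - timedelta(days=LOOKBACK_EXTERNAL_FORECAST_EXISTENCE))
+#       .all()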
+LOOKBACK_EXTERNAL_FORECAST_EXISTENCE = 100 \ No newline at end of file diff --git a/src/dashboard/layout_helper.py b/src/dashboard/layout_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..097c3aaa1f02125c629599031f0e783ce5d67a9b --- /dev/null +++ b/src/dashboard/layout_helper.py @@ -0,0 +1,23 @@ +from dash import html + +def create_collapsible_section(title, content, is_open=True): + return html.Div([ + html.Div([ + html.H3(title, style={'display': 'inline-block', 'marginRight': '10px'}), + html.Button( + '▼' if is_open else '▶', + id={'type': 'collapse-button', 'section': title}, + style={ + 'border': 'none', + 'background': 'none', + 'fontSize': '20px', + 'cursor': 'pointer' + } + ) + ], style={'marginBottom': '10px'}), + html.Div( + content, + id={'type': 'collapse-content', 'section': title}, + style={'display': 'block' if is_open else 'none'} + ) + ]) \ No newline at end of file diff --git a/src/dashboard/styles.py b/src/dashboard/styles.py new file mode 100644 index 0000000000000000000000000000000000000000..93c40ba77912b180b35dd7342f9358cc7d46db5c --- /dev/null +++ b/src/dashboard/styles.py @@ -0,0 +1,215 @@ +from typing import Dict, List, Any +import plotly.express as px + +# Common colors (using Bootstrap colors for consistency) +COLORS = { + 'success': '#198754', + 'danger': '#dc3545', + 'light_gray': '#f4f4f4', + 'border': 'rgba(0,0,0,.125)', + 'background': '#f8f9fa', + 'warning_bg': '#fff3e0', + 'error_bg': '#ffebee', + 'warning' : '#ef6c00', + 'error': '#c62828', +} + + +COLOR_SET = px.colors.qualitative.Set3 + +# Table Styles (This is messy, the conditional style data is not the same for each table) +TABLE_STYLE = { + 'style_table': { + 'overflowX': 'auto', + 'borderRadius': '4px', + 'border': f'1px solid {COLORS["border"]}' + }, + 'style_header': { + 'backgroundColor': COLORS['background'], + 'fontWeight': '600', + 'textAlign': 'left', + 'padding': '12px 16px', + 'borderBottom': f'1px solid {COLORS["border"]}' + }, + 'style_cell': { + 'textAlign': 'left', + 'padding': '12px 16px', + 'fontSize': '14px' + }, + 'style_data_conditional': [ + { + 'if': {'row_index': 'odd'}, + 'backgroundColor': 'rgb(248, 249, 250)' + }, + { + 'if': {'filter_query': '{has_current_forecast} = "✓"', "column_id": "has_current_forecast"}, + 'color': COLORS['success'] + }, + { + 'if': {'filter_query': '{has_current_forecast} = "✗"', "column_id": "has_current_forecast"}, + 'color': COLORS['danger'] + }, + { + 'if': {'filter_query': '{level} = "ERROR"'}, + 'backgroundColor': COLORS['error_bg'], + 'color': COLORS['error'] + }, + { + 'if': {'filter_query': '{level} = "WARNING"'}, + 'backgroundColor': COLORS['warning_bg'], + 'color': COLORS['warning'] + } + ] +} + +TAB_STYLE = { + "active_label_class_name": 'fw-bold', + "label_style": { + 'color': 'rgba(255, 255, 255, 0.9)', + 'padding': '1rem 1.5rem' + }, + "active_label_style": { + 'color': 'white', + 'background': 'rgba(255, 255, 255, 0.1)' + } + } + + +# Historical Table Style +HISTORICAL_TABLE_STYLE = { + 'style_table': { + 'maxHeight': '1200px', + 'maxWidth': '1600px', + 'overflowY': 'auto', + 'width': '100%' + }, + 'style_cell': { + 'textAlign': 'right', + 'padding': '5px', + 'minWidth': '100px', + 'whiteSpace': 'normal', + 'fontSize': '12px' + }, + 'style_header': { + 'backgroundColor': COLORS['light_gray'], + 'fontWeight': 'bold', + 'textAlign': 'center' + } +} + +# Forecast Status Table Style +FORECAST_STATUS_TABLE_STYLE = { + 'style_table': { + 'overflowX': 'auto', + 'width': '100%' + }, + 'style_cell': { + 
'textAlign': 'center', + 'padding': '5px', + 'minWidth': '100px', + 'fontSize': '12px' + }, + 'style_header': { + 'backgroundColor': COLORS['light_gray'], + 'fontWeight': 'bold', + 'textAlign': 'center' + } +} + +# Plot Layouts +STATUS_SUMMARY_LAYOUT = { + 'height': 300, + 'margin': dict(l=30, r=30, t=50, b=30), + 'showlegend': True +} + +MODEL_FORECASTS_LAYOUT = { + 'height': 600, + 'xaxis_title': 'Time', + 'yaxis_title': 'Gauge [cm]', + 'legend': dict( + yanchor="top", + y=0.99, + xanchor="left", + x=1.05 + ) +} + +HISTORICAL_PLOT_LAYOUT = { + 'xaxis_title': 'Time', + 'yaxis_title': 'Value', + 'height': 600, + 'showlegend': True, + 'legend': dict( + yanchor="top", + y=0.99, + xanchor="left", + x=1.05 + ), + 'margin': dict(r=150) +} + +INPUT_FORECASTS_LAYOUT = { + 'showlegend': True, + 'legend': dict( + yanchor="top", + y=0.99, + xanchor="left", + x=1.05, + groupclick="togglegroup", + itemsizing='constant', + tracegroupgap=5 + ), + 'margin': dict(r=150) +} + +# Plot Colors and Traces +FORECAST_COLORS = { + 'historical': 'black', + 'historical_width': 2 +} + +def get_conditional_styles_for_columns(columns: List[str], timestamp_col: str = 'tstamp') -> List[Dict[str, Any]]: + """Generate conditional styles for columns with blank values.""" + styles = [ + { + 'if': { + 'filter_query': '{{{}}} is blank'.format(col), + 'column_id': col + }, + 'backgroundColor': '#ffebee', + 'color': '#c62828' + } for col in columns if col != timestamp_col + ] + + styles.append({ + 'if': {'column_id': timestamp_col}, + 'textAlign': 'left', + 'minWidth': '150px' + }) + + return styles + +def get_forecast_status_conditional_styles(sensor_names: List[str]) -> List[Dict[str, Any]]: + """Generate conditional styles for forecast status table.""" + styles = [] + for col in sensor_names: + styles.extend([ + { + 'if': { + 'filter_query': '{{{col}}} = "Missing"'.format(col=col), + 'column_id': col + }, + 'backgroundColor': '#ffebee', + 'color': '#c62828' + }, + { + 'if': { + 'filter_query': '{{{col}}} = "OK"'.format(col=col), + 'column_id': col + }, + 'backgroundColor': '#e8f5e9', + 'color': '#2e7d32' + } + ]) + return styles \ No newline at end of file diff --git a/src/dashboard/tables_and_plots.py b/src/dashboard/tables_and_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..7348ee462fbee131cd035007ae6c79e80063500d --- /dev/null +++ b/src/dashboard/tables_and_plots.py @@ -0,0 +1,437 @@ +from datetime import datetime, timedelta + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from dash import dash_table, html +from plotly.subplots import make_subplots +import dash_bootstrap_components as dbc + +from dashboard.styles import (FORECAST_STATUS_TABLE_STYLE, + HISTORICAL_TABLE_STYLE, + STATUS_SUMMARY_LAYOUT, + TABLE_STYLE, + COLORS, + COLOR_SET, + get_conditional_styles_for_columns, + get_forecast_status_conditional_styles) +from dashboard.constants import PERIODS_EXT_FORECAST_TABLE + +def create_status_summary_chart(status): + """Creates a simple pie chart showing the percentage of models with current forecasts.""" + total_models = len(status['model_status']) + models_with_forecast = sum(1 for model in status['model_status'] if model['has_current_forecast']) + models_without_forecast = total_models - models_with_forecast + + fig = go.Figure(data=[go.Pie( + labels=['Current Forecast', 'Missing Forecast'], + values=[models_with_forecast, models_without_forecast], + hole=0.4, + marker_colors=[COLORS['success'], COLORS['danger']], + textinfo='percent', + 
hovertemplate="Status: %{label}<br>Count: %{value}<br>Percentage: %{percent}<extra></extra>" + )]) + + fig.update_layout( + title={ + 'text': f'Model Status ({models_with_forecast} of {total_models} models have current forecasts)', + 'y': 0.95 + }, + **STATUS_SUMMARY_LAYOUT + ) + + return fig + + + +def create_model_forecasts_plot(df_forecasts, df_historical,df_inp_fcst): + """Create a plot showing recent model forecasts alongside historical data""" + sensor_name = df_forecasts["sensor_name"][0] + start_times = sorted(set(df_forecasts["tstamp"]) | set(df_inp_fcst["tstamp"])) + start_time_colors = {t: COLOR_SET[i % len(COLOR_SET)] for i, t in enumerate(start_times)} + + + fig = go.Figure() + #fig = make_subplots(specs=[[{"secondary_y": True}]]) + + if not df_historical.empty: + df_measured = df_historical[start_times[0]-timedelta(days=3):] + # Add historical data + fig.add_trace(go.Scatter( + x=df_measured.index, + y=df_measured[sensor_name], + name=f"Messwerte - {sensor_name}", + line=dict(color='black', width=2), + mode='lines' + )) + + df_forecasts = df_forecasts.round(1) + members = df_forecasts['member'].unique() + show_legend_check = set() + for member in members: + member_data = df_forecasts[df_forecasts['member'] == member] + for _, row in filter(lambda x: x[1]["h1"] is not None,member_data.iterrows()): + start_time = row['tstamp'] + if start_time in show_legend_check: + show_legend = False + else: + show_legend = True + show_legend_check.add(start_time) + + legend_group = f'{start_time.strftime("%Y-%m-%d %H:%M")}' + + fig.add_trace( + go.Scatter( + x=[start_time + timedelta(hours=i) for i in range(1, 49)], + y=[row[f'h{i}'] for i in range(1, 49)], + name=legend_group, + legendgroup=legend_group, + showlegend=show_legend, + line=dict( + color=start_time_colors[start_time], + width=3 if member == 0 else 1, + ), + hovertemplate=( + 'Time: %{x}<br>' + 'Value: %{y:.2f}<br>' + f'Member: {member}<br>' + f'Start: {start_time}<extra></extra>' + ), + visible= True if start_time == df_forecasts["tstamp"].max() else 'legendonly' + ), + ) + temp_layout_dict = {"yaxis": {"title" : {"text":"Pegel [cm]"}}} + + fig.update_layout(**temp_layout_dict) + + + ##### + # Add input forecast + if True: + sensors = df_inp_fcst['sensor_name'].unique() + show_legend_check = set() + + # For each sensor + for sensor_idx, sensor in enumerate(sensors, 1): + temp_layout_dict = {f"yaxis{sensor_idx+1}": { + "title" : {"text":"Pegel [cm]"}, + "anchor" : "free", + "overlaying" : "y", + "autoshift":True, + } + } + + fig.update_layout(**temp_layout_dict) + + + sensor_data = df_inp_fcst[df_inp_fcst['sensor_name'] == sensor] + members = sensor_data['member'].unique() + + + # Get unique forecast start times for color assignment + # For each member + for member in members: + member_data = sensor_data[sensor_data['member'] == member] + + # For each forecast (row) + for _, row in member_data.iterrows(): + start_time = row['tstamp'] + #legend_group = f'{sensor} {start_time.strftime("%Y%m%d%H%M")}' + legend_group = f'{start_time.strftime("%Y-%m-%d %H:%M")} {sensor}' + if start_time in show_legend_check: + show_legend = False + else: + show_legend = True + show_legend_check.add(start_time) + + timestamps = [start_time + timedelta(hours=i) for i in range(1, 49)] + values = [row[f'h{i}'] for i in range(1, 49)] + + fig.add_trace( + go.Scatter( + x=timestamps, + y=values, + name=legend_group, + legendgroup=legend_group, + showlegend=show_legend, + line=dict( + color=start_time_colors[start_time], + width=3 if member == 0 else 1, + 
dash='solid' if member == 0 else 'dot' + ), + hovertemplate=( + 'Time: %{x}<br>' + 'Value: %{y:.2f}<br>' + f'Member: {member}<br>' + f'Start: {start_time}<extra></extra>' + ), + visible= True if start_time == df_forecasts["tstamp"].max() else 'legendonly', + yaxis=f"y{sensor_idx+1}" + ) + ) + + fig.update_layout( + #height=600, + #title='Model Forecasts', + #xaxis_title='Time', + #yaxis_title='', + #hovermode='x unified', + legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=1.05, + #bgcolor='rgba(255, 255, 255, 0.8)' + ) + ) + return fig + + + +def create_log_table(logs_data): + """ + Creates a configured DataTable for logging display + + Args: + logs_data: List of dictionaries containing log data + + Returns: + dash_table.DataTable: Configured table for log display + """ + columns = [ + {'name': 'Time', 'id': 'timestamp'}, + {'name': 'Level', 'id': 'level'}, + {'name': 'Message', 'id': 'message'}, + {'name': 'Exception', 'id': 'exception'} + ] + + return dash_table.DataTable( + columns=columns, + data=logs_data, + **TABLE_STYLE + #style_data_conditional=style_data_conditional, + #style_table=style_table, + #style_cell=style_cell + ) + + + + +def create_historical_plot(df_historical, model_name): + """ + Creates a plotly figure for historical sensor data + + Args: + df_historical: DataFrame with sensor measurements + model_name: String name of the model + + Returns: + plotly.graph_objects.Figure: Configured plot + """ + fig = go.Figure() + + # Add a trace for each sensor + for column in df_historical.columns: + fig.add_trace( + go.Scatter( + x=df_historical.index, + y=df_historical[column], + name=column, + mode='lines', + hovertemplate='%{y:.2f}<extra>%{x}</extra>' + ) + ) + + # Update layout + fig.update_layout( + #title=f'Sensor Measurements - Last 144 Hours for {model_name}', + xaxis_title='Time', + yaxis_title='Value', + height=600, + showlegend=True, + legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=1.05 + ), + margin=dict(r=150) + ) + + return fig + +def create_historical_table(df_historical): + """Creates a formatted DataTable for historical sensor data.""" + df_table = df_historical.reset_index() + df_table['tstamp'] = df_table['tstamp'].dt.strftime('%Y-%m-%d %H:%M') + + columns = [{'name': col, 'id': col} for col in df_table.columns] + style_data_conditional = get_conditional_styles_for_columns(df_table.columns) + + #TODO move buttons etc. 
to app.py + hist_table = dash_table.DataTable( + id='historical-table', + columns=columns, + data=df_table.to_dict('records'), + style_data_conditional=style_data_conditional, + fixed_columns={'headers': True, 'data': 1}, + sort_action='native', + sort_mode='single', + **HISTORICAL_TABLE_STYLE + ) + + #return dbc.Card([ + # dbc.CardHeader([ + # dbc.Row([ + # dbc.Col( + # dbc.Button( + # "Show All Data", + # id='toggle-missing-values', + # color="primary", + # className="mb-3", + # n_clicks=0 + # ), + # width="auto" + # ) + # ]) + # ]), + # dbc.CardBody(hist_table) + #]) + return html.Div([ + dbc.Row([ + dbc.Col( + dbc.Button( + "Show All Data", + id='toggle-missing-values', + color="primary", + className="mb-3", + n_clicks=0 + ), + width="auto" + ) + ]), + hist_table + + ]) + + #return hist_table + + + +def create_inp_forecast_status_table(df_forecast,ext_forecast_names): + """ + Creates a status table showing availability of forecasts for each sensor at 3-hour intervals + + Args: + df_forecast: DataFrame with columns tstamp, sensor_name containing forecast data + ext_forecast_names: List of external forecast names + + Returns: + dash.html.Div: Div containing configured DataTable + """ + # Get unique sensor names + #sensor_names = sorted(df_forecast['sensor_name'].unique()) + + # Create index of last 48 hours at 3-hour intervals + last_required_hour = datetime.now().replace(minute=0, second=0, microsecond=0) + while last_required_hour.hour % 3 != 0: + last_required_hour -= timedelta(hours=1) + + + time_range = pd.date_range( + end=last_required_hour, + periods=PERIODS_EXT_FORECAST_TABLE, # 48 hours / 3 + 1 + freq='3h' + ) + + # Initialize result DataFrame with NaN + #status_df = pd.DataFrame(index=time_range, columns=sensor_names) + status_df = pd.DataFrame(index=time_range, columns=ext_forecast_names) + + # For each sensor and timestamp, check if data exists + #for sensor in sensor_names: + for sensor in ext_forecast_names: + sensor_data = df_forecast[df_forecast['sensor_name'] == sensor] + for timestamp in time_range: + #TODO ENSEMBLES + has_data = any( + sensor_data['tstamp'] == timestamp + #(sensor_data['tstamp'] <= timestamp) & + #(sensor_data['tstamp'] + timedelta(hours=48) >= timestamp) + ) + status_df.loc[timestamp, sensor] = 'OK' if has_data else 'Missing' + + # Reset index to make timestamp a column + status_df = status_df.reset_index() + status_df['index'] = status_df['index'].dt.strftime('%Y-%m-%d %H:%M') + + # Configure table styles + style_data_conditional = get_forecast_status_conditional_styles(ext_forecast_names) + + return html.Div( + dash_table.DataTable( + id='forecast-status-table', + columns=[ + {'name': 'Timestamp', 'id': 'index'}, + *[{'name': col, 'id': col} for col in ext_forecast_names] + ], + data=status_df.to_dict('records'), + style_data_conditional=style_data_conditional, + fixed_columns={'headers': True, 'data': 1}, + sort_action='native', + sort_mode='single', + **FORECAST_STATUS_TABLE_STYLE + ) + ) + + +def create_metrics_plots(df_metrics, metric_name): + """Create plots for model metrics""" + + fig = make_subplots( + rows=2, + cols=1, + subplot_titles=["Gesamt", "Flutbereich (95. 
Percentil)"], + vertical_spacing=0.1 + ) + + for col in df_metrics: + row = 2 if 'flood' in col else 1 + set_name = col.split("_")[0] + if set_name == 'train': + color = COLOR_SET[0] + elif set_name == 'val': + color = COLOR_SET[1] + elif set_name == 'test': + color = COLOR_SET[2] + else: + color = COLOR_SET[3] + + + fig.add_trace( + go.Scatter( + x=df_metrics.index, + y=df_metrics[col], + name=set_name, + legendgroup=set_name, + showlegend= row==1, + mode='lines+markers', + line={"color":color}, + ), + row=row, + col=1 + ) + + fig.update_layout( + xaxis_title='Vorhersagehorizont', + yaxis_title=metric_name.upper(), + height=600, + showlegend=True, + legend={ + "yanchor": 'top', + "y": 0.99, + "xanchor": 'left', + "x": 1.05}, + margin={"r":50} + ) + return fig diff --git a/src/scripts/icon_export_v3.py b/src/icon_export_v3.py similarity index 90% rename from src/scripts/icon_export_v3.py rename to src/icon_export_v3.py index f7d765420c58bba913f4dfd76d475236cdfb6c3a..931fc516a2b2bee65e6f7949cd75c82235204d95 100755 --- a/src/scripts/icon_export_v3.py +++ b/src/icon_export_v3.py @@ -11,44 +11,58 @@ import subprocess as sp import time from pathlib import Path from typing import List, Tuple - import numpy as np import pandas as pd import pygrib from shapely.geometry import Point from shapely.geometry.polygon import Polygon -from sympy import elliptic_f import yaml -from io import BytesIO +import oracledb +from utils.db_tools.db_logging import OracleDBHandler logger = logging.getLogger("icon_export") logger.setLevel(logging.DEBUG) -def setup_logging(log_file: Path, debug: bool, max_bytes: int = 50000) -> None: +def setup_logging(log_file: Path, debug: bool, max_bytes: int = 50000,con=None) -> None: """Set up rotating log files Args: log_file (Path): Where to log to debug (bool): Wheather to log debug level messages max_bytes (int, optional): maximum filesize for the logfile. Defaults to 50000. + con ([type], optional): Oracle connection. Defaults to None. 
""" + ####### DB Logging + old_factory = logging.getLogRecordFactory() + def record_factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + record.gauge_id = kwargs.pop("gauge_id", None) + return record + logging.setLogRecordFactory(record_factory) + #### log_file.parent.mkdir(parents=True, exist_ok=True) - formatter = logging.Formatter( - fmt="%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S" + log_formatter = logging.Formatter( + "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) - logHandlerFile = logging.handlers.RotatingFileHandler( log_file, maxBytes=max_bytes, backupCount=1 ) logHandlerFile.setLevel(logging.DEBUG if debug else logging.INFO) - logHandlerFile.setFormatter(formatter) + logHandlerFile.setFormatter(log_formatter) logger.addHandler(logHandlerFile) logHandlerStream = logging.StreamHandler() logHandlerStream.setLevel(logging.DEBUG if debug else logging.INFO) - logHandlerStream.setFormatter(formatter) + logHandlerStream.setFormatter(log_formatter) logger.addHandler(logHandlerStream) + ####### DB Logging + if con is not None: + db_handler = OracleDBHandler(con) + db_handler.setFormatter(log_formatter) + db_handler.setLevel(logging.DEBUG if debug else logging.INFO) + logger.addHandler(db_handler) + def parse_args() -> argparse.Namespace: @@ -308,12 +322,12 @@ class ICON_Exporter: #Select the smallest file try: - grib_file = sorted( + grib_file = sorted( filter(lambda x: x.suffix == ".grib2", self.source_folder.iterdir()), key=lambda x: x.stat().st_size, )[0] - except IndexError: - raise FileNotFoundError("No .grib2 files found in source folder") + except IndexError as e: + raise FileNotFoundError("No .grib2 files found in source folder") from e lats, lons = pygrib.open(str(grib_file))[1].latlons() # pylint: disable=no-member logger.info( "Generating index files from %s, grid shape is %s", @@ -334,7 +348,7 @@ class ICON_Exporter: else: index_files = list(filter(lambda x: x.suffix == ".index", self.index_folder.iterdir())) - logger.info("Loading index files %s", index_files) + logger.info("Loading index files from folder %s", self.index_folder) self.indexes = [pd.read_csv(index_file, index_col=0).values.flatten() for index_file in index_files] self.names = list(map(get_area_name, index_files)) @@ -406,8 +420,15 @@ class ICON_Exporter: source_base = str(tmp_folder / (source_file.stem + "_")) self.extract_ensemble(source_base) for temp_file in tmp_folder.iterdir(): - temp_file.unlink() - tmp_folder.rmdir() + try: + temp_file.unlink() + except Exception as e: + logger.error("Could not delete file %s: %s", temp_file, e) + try: + tmp_folder.rmdir() + except OSError as e: + logger.error("Could not delete folder %s: %s", tmp_folder, e) + def extract_ensemble(self,source_base: str) -> None: """Extracts the predicted mean rainfall per hour for each ensemble member and area/catchment and exports them into zrx files. 
@@ -441,9 +462,6 @@ class ICON_Exporter: for name, cur_dict in dict_dict.items(): dict_to_zrx(cur_dict, self.target_folder, source_file, name, member=i) - - #Remove the grib2 file created by grib_copy - source_file.unlink() def process_hourly(self) -> None: @@ -518,7 +536,27 @@ def main() -> None: parsed_args = parse_args() # configs/icon_exp_ensemble.yaml config = get_config(parsed_args) - setup_logging(config["log_file"], config["debug"]) + if "db_params" in config: + try: + if "config_dir" in config["db_params"]: + logging.info( + "Initiating Thick mode with executable %s", + config["db_params"]["config_dir"], + ) + oracledb.init_oracle_client(lib_dir=config["db_params"]["config_dir"]) + else: + logging.info("Initiating Thin mode") + con = oracledb.connect(**config["db_params"]) + except oracledb.OperationalError as e: + logger.error("Could not connect to DB: %s", e) + con = None + else: + con = None + logging.info("No DB connection") + + #con = oracledb.connect(**main_config["db_params"]) + + setup_logging(config["log_file"], config["debug"],con=con) logger.info( "Start of this execution======================================================================" ) diff --git a/src/import_model.py b/src/import_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cf40609089f652fad3af821c8a707ad0aa81d205 --- /dev/null +++ b/src/import_model.py @@ -0,0 +1,327 @@ +"""Short script to add models to the database.""" +import argparse +import logging +import sys +from pathlib import Path +import pandas as pd +import oracledb +import yaml +from sqlalchemy import create_engine, func, select +from sqlalchemy.orm import Session +from tbparse import SummaryReader + +import utils.helpers as hp +from utils.db_tools.db_logging import OracleDBHandler +from utils.db_tools.orm_classes import InputForecastsMeta, Modell,ModellSensor, Sensor, Metric + + +def parse_args() -> argparse.Namespace: + """Parse all the arguments and provides some help in the command line""" + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description="Add models to the database. Needs a YAML File" + ) + + parser.add_argument( + "yaml_path", + metavar="yaml_path", + type=Path, + help="""The path to your config file. Example config file: + db_params: + user: c##mspils + #Passwort ist für eine lokale Datenbank ohne verbindung zum internet. Kann ruhig online stehen. + password: cobalt_deviancy + dsn: localhost/XE + models: + - gauge: 114547,S,60m.Cmd + model_name : mein_model + model_folder: ../../models_torch/tarp_version_49/ + columns: + - 4466,SHum,vwsl + - 4466,SHum,bfwls + - 4466,AT,h.Cmd + - 114435,S,60m.Cmd + - 114050,S,60m.Cmd + - 114547,S,60m.Cmd + - 114547,Precip,h.Cmd + external_fcst: + - 114547,Precip,h.Cmd + """ + ) + return parser.parse_args() + + +def get_configs(yaml_path: Path): + """ + Retrieves the configurations from a YAML file. + Also converts the paths to Path objects and the dates to datetime objects. + + Args: + passed_args (argparse.Namespace): The command-line arguments. + + Returns: + Tuple[dict, List[dict]]: A tuple containing the main configuration dictionary + and a list of gauge configuration dictionaries. 
+ """ + + with open(yaml_path, "r", encoding="utf-8") as file: + config = yaml.safe_load(file) + + for model_config in config["models"]: + model_config["model_folder"] = Path(model_config["model_folder"]) + assert all([isinstance(c,str) for c in model_config["columns"]]) + assert isinstance(model_config["gauge"],str) + #assert all([isinstance(k,str) and isinstance(v,str) for k,v in gauge_config["external_fcst"].items()]) + assert all([isinstance(c,str) for c in model_config["external_fcst"]]) + return config + +def _prepare_logging(con) -> None: + old_factory = logging.getLogRecordFactory() + def record_factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + record.gauge_id = kwargs.pop("gauge_id", None) + return record + logging.setLogRecordFactory(record_factory) + + log_formatter = logging.Formatter( + "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + db_handler = OracleDBHandler(con) + db_handler.setFormatter(log_formatter) + + db_handler.setLevel(logging.INFO) + logging.getLogger().addHandler(db_handler) + logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:]) + + +class ModelImporter: + """Class for handling the import of models into the database. + """ + def __init__(self, con: oracledb.Connection): + self.con = con + self.engine = create_engine("oracle+oracledb://", creator=lambda: con) + + def process_model(self, session: Session, model_config: dict): + """Process a model configuration and add it to the database. This includes adding the model, sensors,model sensors and input_forecast_meta objects. + + Args: + session (Session): orm session + model_config (dict): model configuration dictionary.Example: + + "gauge": "mstr2,parm_sh,ts-sh" + "model_name": "mein_model_unfug" + "model_folder": "../../models_torch/tarp_version_49/" + "columns": ["a,b,c", + "mstr2,parm_sh,ts-sh"] + "external_fcst": ["a,b,c"] + """ + self._validate_config(model_config) + model_hparams = hp.load_settings_model(model_config["model_folder"]) + + # Check if model already exists in database + if (old_model := self._get_model_maybe(session, model_config["model_name"])) is not None: + logging.warning("Model with name %s already exists in database,skipping",old_model) + #TODO updating requires also updating the ModellSensor table, otherwise we get 2 rows with the same modell_id sensor_name combination. 
+ #if mi._get_user_confirmation(f"Press y to confirm changing the model {old_model} or n to skip"): + # logging.info(f"Updating model {old_model}") + # sensor_obj = mi._parse_sensor_maybe_add(session,target_sensor_name) + # model_id = old_model.id + # mi._update_model(old_model,sensor_obj,model_config,model_hparams) + #else: + # continue + return + else: + # Add Model to database + model_id = self._add_model(session,model_config,model_hparams) + + self._maybe_add_metrics(session,model_id,model_config,model_hparams) + + for i,col in enumerate(model_config["columns"]): + _ = self._parse_sensor_maybe_add(session,col) + inp_fcst_meta = self._maybe_add_inp_fcst_meta(session,model_config,col) + vhs_gebiet = None if inp_fcst_meta is None else inp_fcst_meta.vhs_gebiet + _ = self._maybe_add_modelsensor(session,model_id,col,vhs_gebiet,i) + + def _maybe_add_metrics(self,session,model_id,model_config,model_hparams): + #TODO error handling and updating old models with metrics + reader = SummaryReader(model_config["model_folder"],pivot=True) + df =reader.scalars + loss_cols = df.columns[df.columns.str.contains("loss")] + df = df.drop(["step","epoch"] + list(loss_cols),axis=1)[1:49] + + + metrics = [] + for col in df: + for i,row in enumerate(df[col]): + col_name = col + if col_name.startswith("hp/"): + col_name = col_name[3:] + + set_name = col_name.split("_")[0] + if set_name == "train": + start_time = pd.to_datetime(model_hparams["train_start"]) + end_time = pd.to_datetime(model_hparams["val_start"]) + elif set_name == "val": + start_time = pd.to_datetime(model_hparams["val_start"]) + end_time = pd.to_datetime(model_hparams["test_start"]) + elif set_name == "test": + start_time = pd.to_datetime(model_hparams["test_start"]) + + #TODO ungenauer Pfusch, besser end_time loggen + not_test_time = (pd.to_datetime(model_hparams["test_start"]) - pd.to_datetime(model_hparams["train_start"])) + not_test_share = model_hparams["train"]+model_hparams["val"] + test_share = 1-not_test_share + + end_time = start_time + (not_test_time/not_test_share*test_share) + end_time = end_time.round("h") + if "flood" in col_name: + tag = "flood" + else: + tag = "normal" + + metric = Metric(model_id=model_id,metric_name=col_name,timestep=i+1,start_time=start_time,end_time=end_time,value=row,tag=tag) + metrics.append(metric) + + session.add_all(metrics) + + def _get_next_model_id(self, session: Session) -> float: + """Get next available model ID""" + #TODO the database should use an IDENTITY column for the model table, then we could remove this + result = session.execute(select(func.max(Modell.id))).scalar() + return 1 if result is None else result + 1 + + def _get_sensor_maybe(self, session: Session, sensor_name: str): + """Check if model exists in database""" + stmt = select(Sensor).where(Sensor.sensor_name == sensor_name) + return session.scalars(stmt).first() + + def _get_model_maybe(self, session: Session, model_name: str) -> Modell: + """Check if model exists in database""" + stmt = select(Modell).where(Modell.modellname == model_name) + return session.scalars(stmt).first() + + def _get_user_confirmation(self, message: str) -> bool: + """Get user confirmation for action""" + response = input(f"{message} (y/n): ").lower() + return response == 'y' + + def _get_user_input(self, message: str) -> str: + """Get user input for action""" + response = input(f"{message}: ") + return response + + def _parse_sensor_maybe_add(self,session,sensor_name: str) -> Sensor: + """ Trie to get sensor from SENSOR table, if not found, tries to 
interpolate values from sensor name or user input. + Also creates and adds sensor to SENSOR table if not found. + """ + if (sensor_obj := self._get_sensor_maybe(session,sensor_name)) is None: + logging.info("Sensor %s not found in table SENSOR, attempting interpolation from string, adding to database",sensor_name) + try: + mstnr, type_short, ts_short = sensor_name.split(",") + except ValueError as e: + logging.error("Could not interpolate values from sensor name %s, expected format is 'mstnr, parameter_kurzname, TS Shortname' comma separated",sensor_name) + raise ValueError from e + + sensor_obj = Sensor(mstnr=mstnr,parameter_kurzname=type_short,ts_shortname=ts_short,beschreibung="Automatisch generierter Eintrag",sensor_name=sensor_name) + session.add(sensor_obj) + logging.info(msg=f"Added Sensor {sensor_obj} to database") + return sensor_obj + + def _maybe_add_modelsensor(self,session,model_id,col,vhs_gebiet,ix) -> ModellSensor: + """ Creates a ModellSensor object and adds it to the database. + """ + model_sensor = ModellSensor( + modell_id = model_id, + sensor_name = col, + vhs_gebiet = vhs_gebiet, + ix = ix + ) + session.add(model_sensor) + return model_sensor + + def _maybe_add_inp_fcst_meta(self,session,model_config:dict,col:str) -> InputForecastsMeta: + """Get or add input forecast metadata to database""" + if col not in model_config["external_fcst"]: + return None + if (inp_fcst := session.scalars(select(InputForecastsMeta).where(InputForecastsMeta.sensor_name == col)).first()) is None: + logging.info("Input forecast %s not found in table INPUT_FORECASTS_META, adding to database",col) + vhs_gebiet = self._get_user_input(f"Enter the name (vhs_gebiet) of the area or station for the input forecast {col}") + ensemble_members = self._get_user_input(f"Enter the number of ensemble members for the input forecast {col}, 1 for deterministic, 21 for ICON, other values not supported") + ensemble_members = int(ensemble_members) + assert ensemble_members in [1,21], "Only 1 (deterministic) and 21 (ICON) ensemble + deterministic are supported" + + inp_fcst = InputForecastsMeta(sensor_name=col,vhs_gebiet=vhs_gebiet,ensemble_members=ensemble_members) + session.add(inp_fcst) + logging.info(msg=f"Added Input_Forecasts_Meta {inp_fcst} to database") + + return inp_fcst + + def _validate_config(self,model_config): + if not set(model_config["external_fcst"]) <= set(model_config["columns"]): + logging.error("Not all external forecasts are in columns") + raise ValueError + if model_config["gauge"] not in model_config["columns"]: + logging.error("Gauge not in columns") + raise ValueError("Gauge %s not in columns" % model_config["gauge"]) + + def _add_model(self,session,model_config,model_hparams) -> int: + target_sensor_name = model_config["gauge"] + sensor_obj = self._parse_sensor_maybe_add(session,target_sensor_name) + model_id = self._get_next_model_id(session) + model_obj = Modell(id=model_id, + mstnr=sensor_obj.mstnr, + modellname=model_config["model_name"], + modelldatei=str(model_config["model_folder"].resolve()), + aktiv=1, + kommentar=' \n'.join(model_hparams["scaler"].feature_names_in_), + in_size=model_hparams["in_size"], + target_sensor_name=target_sensor_name, + ) + session.add(model_obj) + logging.info("Added model %s to database",model_obj) + return model_obj.id + + + + def _update_model(self,old_model,sensor_obj,model_config,model_hparams): + #def update_log_lambda(name,old,new): kurz, aber hacky. 
Would also need to be adapted with __dict__ or something similar + # if old != new: + # logging.info(f"Updating {old_model} {name} from {old} to {new}") + # old = new + if old_model.mstnr !=sensor_obj.mstnr: + logging.info(msg=f"Updating {old_model} mstnr to {sensor_obj.mstnr}") + old_model.mstnr =sensor_obj.mstnr + if old_model.modelldatei != (modelldatei:=str(model_config["model_folder"].resolve())): + logging.info(msg=f"Updating {old_model} modelldatei to {modelldatei}") + old_model.modelldatei=modelldatei + if old_model.kommentar != (kommentar:= ' \n'.join(model_hparams["scaler"].feature_names_in_)): + logging.info(msg=f"Updating {old_model} kommentar to {kommentar}") + old_model.kommentar=kommentar + if old_model.in_size != model_hparams["in_size"]: + logging.info(msg=f"Updating {old_model} in_size to {model_hparams['in_size']}") + old_model.in_size=model_hparams["in_size"] + if old_model.target_sensor_name != model_config["gauge"]: + logging.info(msg=f"Updating {old_model} target_sensor_name to {model_config['gauge']}") + old_model.target_sensor_name= model_config["gauge"] + + +def main(): + """Main function to import models into the database""" + passed_args = parse_args() + + config = get_configs(passed_args.yaml_path) + if "config_dir" in config["db_params"]: + logging.info("Initiating Thick mode with executable %s",config["db_params"]["config_dir"],) + oracledb.init_oracle_client(lib_dir=config["db_params"]["config_dir"]) + else: + logging.info("Initiating Thin mode") + con = oracledb.connect(**config["db_params"]) + _prepare_logging(con) + + mi = ModelImporter(con) + with Session(mi.engine) as session: + for model_config in config["models"]: + mi.process_model(session,model_config) + session.commit() + +if __name__ == "__main__": + main() diff --git a/src/predict_database.py b/src/predict_database.py index ab30b028f4508aad912fbcf66d00f12d25ff2bdb..5ef0d5b27fd49f6f0a940b98ca6b69da81a90983 100644 --- a/src/predict_database.py +++ b/src/predict_database.py @@ -11,7 +11,8 @@ import oracledb import pandas as pd import yaml -from utils.db_tools import OracleDBHandler, OracleWaVoConnection +from utils.db_tools.db_tools import OracleWaVoConnection +from utils.db_tools.db_logging import OracleDBHandler def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]: @@ -34,10 +35,15 @@ def get_configs(passed_args: argparse.Namespace) -> Tuple[dict, List[dict]]: if isinstance(main_config["zrxp_folder"], str): main_config["zrxp_folder"] = [main_config["zrxp_folder"]] main_config["zrxp_folder"] = list(map(Path, main_config["zrxp_folder"])) + main_config["zrxp_out_folder"] = Path(main_config["zrxp_out_folder"]) main_config["sensor_folder"] = Path(main_config["sensor_folder"]) main_config["start"] = pd.to_datetime(main_config["start"]) if "end" in main_config: main_config["end"] = pd.to_datetime(main_config["end"]) + + if main_config.get("fake_external"): + assert main_config["range"] == True, "fake_external only works with range=True" + if passed_args.start is not None: main_config["start"] = passed_args.start main_config.pop("end", None) @@ -68,14 +74,14 @@ def _prepare_logging_2(con) -> None: return record logging.setLogRecordFactory(record_factory) - logFormatter = logging.Formatter( + log_formatter = logging.Formatter( "%(asctime)s;%(levelname)s;%(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) - dbHandler = OracleDBHandler(con) - dbHandler.setFormatter(logFormatter) + db_handler = OracleDBHandler(con) + db_handler.setFormatter(log_formatter) - dbHandler.setLevel(logging.INFO) - 
logging.getLogger().addHandler(dbHandler) + db_handler.setLevel(logging.INFO) + logging.getLogger().addHandler(db_handler) logging.info("Executing %s with parameters %s ", sys.argv[0], sys.argv[1:]) @@ -98,8 +104,8 @@ def parse_args() -> argparse.Namespace: help="The start date for the prediction. Format: YYYY-MM-DD HH:MM, overwrites start/end/range in the config file.", ) parser.add_argument( - "--zrxp", action="store_true", help="Save predictions as ZRXP files (not yet implemented)" - ) #TODO: Implement ZRXP + "--zrxp", action="store_true", help="Save predictions as ZRXP files" + ) return parser.parse_args() @@ -129,7 +135,6 @@ def main(passed_args) -> None: else: logging.info("Initiating Thin mode") connector = OracleWaVoConnection(con, main_config) - # connector.maybe_create_tables() # Create tables if they don't exist connector.maybe_update_tables() # Potentially load new data for gauge_config in gauge_configs: @@ -139,6 +144,5 @@ def main(passed_args) -> None: if __name__ == "__main__": - # TODO argparse einschränken auf bestimmte pegel? - args = parse_args() - main(args) # parse_args() + my_args = parse_args() + main(my_args) # parse_args() diff --git a/src/scripts/bsh_extractor_v2.py b/src/scripts/bsh_extractor_v2.py deleted file mode 100644 index 0a6e1c5843adc2547e9a1a01209a3c4f81e579ae..0000000000000000000000000000000000000000 --- a/src/scripts/bsh_extractor_v2.py +++ /dev/null @@ -1,171 +0,0 @@ -import argparse -from io import TextIOWrapper -from pathlib import Path -from typing import Tuple -from zipfile import ZipFile - -import pandas as pd - - -def parse_arguments() -> Path: - """Parse command line arguments. - - Returns: - Path: Directory with MOS.zip file. - """ - parser = argparse.ArgumentParser(description='Process root folder.') - parser.add_argument('base_dir', type=Path, help='Path to the root folder') - args = parser.parse_args() - return args.base_dir - -def get_offset(hour:int) -> int: - """ Calculates how many values to skip from the input hour to get to an hour divisible by 3 - - Args: - hour (int): Hour part of the timestamp for the first predicted value - - Returns: - int: The offset - """ - if hour % 3 == 0: - return 1 - if hour % 3 == 1: - return 0 - if hour % 3 == 2: - return 2 - - -def get_dat_file_name(zip_file_name:Path) -> Tuple[str, str]: - """Finds the complete Path to the Eidersperrwerk AP data file and the timestamp of when the forecast was made. - - Args: - zip_file_name (Path): Path to the MOS.zip file - - Returns: - Tuple[str, str]: The complete path to the data file and the timestamp of the forecast - """ - with ZipFile(zip_file_name,'r')as my_zip: - with my_zip.open("MOS/Pegel-namen-nummern.txt") as pegelnames: - pegelnames_str = TextIOWrapper(pegelnames,encoding="cp1252") #ANSI - findNr4EiderSpAP(pegelnames_str) - - dat_file = next(filter(lambda x: x.startswith("MOS/pgl/pgl_664P_"),my_zip.namelist())) - tsVorh = dat_file[-16:-4] - return dat_file,tsVorh - -def findNr4EiderSpAP(pegelnames:TextIOWrapper)-> None: - """Überreste vom LFU Praktikum. - Prüft ob DE__664P noch die Nummer für Eidersperrwerk AP ist. Wenn nicht wird das Programm beendet. 
- - Args: - pegelnames (TextIOWrapper): TextIOWrapper für die Pegelnamen Datei - """ - # Überreste Praktikant - # suche Zeile: DE__664P 9530010 Eider-Sperrwerk, Außenpegel - for row in pegelnames.readlines(): - if 'DE__664P' in row: - if 'Eider-Sperrwerk, Außenpegel' in row: - print(" Name-Nummer korrekt ") - return - else: - print("ERROR: Die Nummer von Eidersperrwerk AP hat anscheindend gewechselt") - quit() - quit() - -def read_dat_file(zip_file:Path,dat_file:str,tsVorh:str) -> pd.DataFrame: - """Opens and extracts the data from the MOS.zip file, and returns a DataFrame with the relevant data from dat_file. - tsVorh is used for the timestamp column of the forecast. - - Args: - zip_file (Path): Path to the MOS.zip file - dat_file (str): Path to the data file in the zip file - tsVorh (str): String representation of the timestamp of the making of the forecast - - Returns: - pd.DataFrame: _description_ - """ - with ZipFile(zip_file,'r') as my_zip: - with my_zip.open(dat_file) as csv_file: - csv_file_str = TextIOWrapper(csv_file) - - df = pd.read_fwf(filepath_or_buffer=csv_file_str,colspecs=[(7,19),(23,26),(26,28),(40,43),(45,48)],header=None,skiprows=1,parse_dates=[0],date_format="%Y%m%d%H%M") - df.columns= ["timestamp","stunden","minuten","gztmn","stau"] - df["member"] = 0 - df["pegel"] = df["gztmn"] + df["stau"] - df["forecast"]= df["timestamp"] + pd.to_timedelta(df["stunden"],unit="h") + pd.to_timedelta(df["minuten"],unit="m") + pd.to_timedelta(1,unit="h") - - df["timestamp"] = tsVorh #TODO @Ralf diese Zeile finde ich etwas fragwürdig. Ich vermute dass diese Zeit falsch und die IN der datei richtig ist. - - df = df[["timestamp","forecast","member","pegel"]] - return df - -def writeZrxp(target_file:Path,df:pd.DataFrame): - """Writes the DataFrame to a .zrx file with the correct header. - Overwrites the file if it already exists. - - Args: - target_file (Path): Where the file should be saved - df (pd.DataFrame): DataFrame with zrx data - """ - if target_file.exists(): - print(f"Overwriting existing file {target_file}") - target_file.unlink() - - with open(target_file , 'w', encoding='utf-8') as file: - file.write('#REXCHANGE9530010.W.BSH_VH|*|\n') - file.write('#RINVAL-777|*|\n') - file.write('#LAYOUT(timestamp,forecast,member,value)|*|\n') - - df.to_csv(path_or_buf =target_file , - header = False, - index=False, - mode='a', - sep = ' ', - date_format = '%Y%m%d%H%M') - -def make_target_file_name(base_dir:Path,tsVorh) -> Path: - """Creates the target file name for the .zrx file from the base_dir and the timestamp of the forecast. - - Args: - base_dir (Path): Base directory - tsVorh (_type_): String representation of the timestamp of the making of the forecast. (At least we pretend it is) - - Returns: - Path: The complete path to the target file - """ - if isinstance(tsVorh,str): - return base_dir / (tsVorh + "_9530010_EidersperrwerkAP_Vh.zrx") - return base_dir / (tsVorh.strftime("%Y%m%d%H%M") + "_9530010_EidersperrwerkAP_Vh.zrx") - -def main(): - """Main function of the script. 
Reads the data from the MOS.zip file, extracts the relevant data and writes it .zrx files.""" - base_dir = parse_arguments() - mos_file_name = base_dir / "MOS.zip" - dat_file,tsVorh = get_dat_file_name(mos_file_name) - target_file = make_target_file_name(base_dir,tsVorh) - - df = read_dat_file(mos_file_name,dat_file,tsVorh) - writeZrxp(target_file,df) - - #Resample and cut data to fit the 3h interval, hourly sample rate used by ICON - df2 = df[["forecast","pegel"]].resample("1h",on="forecast").first() - df2["forecast"] = df2.index - offset = get_offset(df2["forecast"].iloc[0].hour) - df2 = df2.iloc[offset:] - df2.loc[:,"timestamp"] = df2["forecast"].iloc[0] - pd.Timedelta("1h") - df2["member"] = 0 - df2 = df2[["timestamp","forecast","member","pegel"]] - - #Use the one prediction as forecasts for the next the possible predict_db runs - #If another bsh forecast is made before the next predict_db run, the forecast will be overwritten in the database. - for offset in [0,3,6]: - df3 = df2.copy() - df3["timestamp"] = df3["timestamp"] + pd.Timedelta(hours=offset) - target_file = make_target_file_name(base_dir,df3["timestamp"].iloc[0]) - writeZrxp(target_file,df3[offset:offset+48]) - - - - -if __name__ == "__main__": - main() diff --git a/src/scripts/extract_metrics.py b/src/scripts/extract_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..5f170d4042f18d32d1782cd3392b936d0b315aba --- /dev/null +++ b/src/scripts/extract_metrics.py @@ -0,0 +1,54 @@ +import argparse +from ast import arg +from tbparse import SummaryReader +import os +from pathlib import Path + + + + +def extract_summaries_to_csv(root_dir,metrics=None,force=False): + """Takes a root directory and extracts metrics from the tensorboard file into a metrics.csv file. + Tensorboard files look something like events.out.tfevents.1701686993.kiel.2222904.0 + + default metrics are "hp/test_p10","hp/test_p20","hp/test_rmse","hp/test_p10_flood","hp/test_p20_flood","hp/test_rmse_flood" + + + Args: + root_dir (str/path): Directory with folders containing tensorboard files and models. Those are expected to beging with version + metrics (list[str], optional): List of metrics to extract. Defaults to None. + """ + + if metrics is None: + metrics = ["hp/test_p10","hp/test_p20","hp/test_rmse","hp/test_p10_flood","hp/test_p20_flood","hp/test_rmse_flood"] + + for root, _, files in os.walk(root_dir): + if root.split("/")[-1].startswith("version"): + if "metrics.csv" in files and not force: + print(f"Skipping {root} as metrics.csv already exists. Use --force to overwrite") + continue + #log_dir = "../../../data-project/KIWaVo/models/lfu/willenscharen6/lightning_logs/version_0/" + reader_hp = SummaryReader(root,pivot=True) + df =reader_hp.scalars + df = df.drop(["step","epoch"],axis=1)[1:49] + df = df[metrics] + df.columns = df.columns.str.slice(3) + df.to_csv(root + "/metrics.csv") + print(f"Extracted metrics for {root}") + + +def main(): + parser = argparse.ArgumentParser("Extract metrics from tensorboard files to csv. Will check folders that start with 'version'") + parser.add_argument("root_dir", type=str, help="Root directory with tensorboard files") + parser.add_argument("--metrics", type=str, help="""Metrics to extract. + Default: "hp/test_p10","hp/test_p20","hp/test_rmse","hp/test_p10_flood","hp/test_p20_flood","hp/test_rmse_flood" + It is probably easiert to modify the source code if you want to change the metrics. 
+ """, nargs="+") + parser.add_argument("--force", action="store_true", help="Force overwrite of existing metrics.csv files") + + args = parser.parse_args() + print(args) + extract_summaries_to_csv(args.root_dir,args.metrics,args.force) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/utils/db_tools/db_logging.py b/src/utils/db_tools/db_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..66e15c7e9a50bde1e4809d4efa109400bae23c41 --- /dev/null +++ b/src/utils/db_tools/db_logging.py @@ -0,0 +1,57 @@ +import logging +from utils.db_tools.orm_classes import Log +from sqlalchemy import create_engine +from sqlalchemy.orm import Session +import pandas as pd + +class OracleDBHandler(logging.Handler): + """ + adapted from https://gist.github.com/ykessler/2662203 + """ + + def __init__(self, con): + logging.Handler.__init__(self) + self.con = con + self.cursor = con.cursor() + self.engine = create_engine("oracle+oracledb://", creator=lambda: con) + self.session = Session(bind=self.engine) + + # self.query = "INSERT INTO LOG VALUES(:1, :2, :3, :4, :5, :6, :7)" + + # Create table if needed: + table_list = [ + row[0] for row in self.cursor.execute("SELECT table_name FROM user_tables") + ] + if "LOG" not in table_list: + query = """CREATE TABLE LOG( + ID int GENERATED BY DEFAULT AS IDENTITY (START WITH 1), + Created TIMESTAMP, + LogLevelName varchar(32), + Message varchar(256), + Module varchar(64), + FuncName varchar(64), + LineNo int, + Exception varchar(256) + CONSTRAINT PK_LOG PRIMARY KEY (ID) + ) ;""" + self.cursor.execute(query) + self.con.commit() + + def emit(self, record): + if record.exc_info: + record.exc_text = logging._defaultFormatter.formatException(record.exc_info) + else: + record.exc_text = "" + + log = Log( + created=pd.to_datetime(record.created,unit="s"), + loglevelname=record.levelname, + message=record.message, + module=record.module, + funcname=record.funcName, + lineno=record.lineno, + exception=record.exc_text, + gauge=record.gauge_id, + ) + self.session.add(log) + self.session.commit() \ No newline at end of file diff --git a/src/utils/db_tools.py b/src/utils/db_tools/db_tools.py similarity index 62% rename from src/utils/db_tools.py rename to src/utils/db_tools/db_tools.py index 65af2e2533853a74b1b2acd76919407d62b38eea..df4a9e558c44a56a8cf0ea6ddfbc61bb81613732 100644 --- a/src/utils/db_tools.py +++ b/src/utils/db_tools/db_tools.py @@ -3,41 +3,97 @@ A collection of classes and functions for interacting with the oracle based Wavo """ import logging -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path -from typing import List, Tuple +from typing import List import warnings - +from functools import partial # from warnings import deprecated (works from python 3.13 on) +from functools import reduce import numpy as np import oracledb import pandas as pd from sqlalchemy import create_engine -from sqlalchemy import select, bindparam, between, update, insert +from sqlalchemy import select, bindparam, between, update from sqlalchemy.orm import Session -from sqlalchemy.dialects.oracle import FLOAT, TIMESTAMP, VARCHAR2, NUMBER +from sqlalchemy.dialects.oracle import FLOAT, TIMESTAMP, VARCHAR2 from sqlalchemy.exc import IntegrityError import torch -from torch.utils.data import Dataset, DataLoader -import lightning.pytorch as pl - -from tqdm import tqdm import utils.helpers as hp -from utils.orm_classes import ( +from utils.utility import scale_standard +from utils.db_tools.orm_classes 
import ( ModellSensor, SensorData, InputForecasts, PegelForecasts, Log, Sensor, - InputForecastsMeta + InputForecastsMeta, + Modell ) - +from torch.utils.data import DataLoader,Dataset +#from data_tools.datasets import TimeSeriesDataSet # pylint: disable=unsupported-assignment-operation # pylint: disable=unsubscriptable-object +# pylint: disable=line-too-long + + + +class DBDataSet(Dataset): + def __init__(self, df_scaled,in_size,ext_meta,df_ext=None) -> None: + self.df_scaled = df_scaled + self.in_size = in_size + self.df_ext = df_ext + self.ext_meta = ext_meta #TODO rename + self.y = torch.Tensor(([0])) # dummy + + + if df_ext is None: + self.max_ens_size = 1 + self.common_indexes = df_scaled.index[self.in_size:] #TODO check if this is correct + else: + self.max_ens_size = max([i_meta.ensemble_members for i_meta in ext_meta]) + self.unique_indexes = [df_ext[df_ext["sensor_name"] == meta.sensor_name].index.unique() for meta in ext_meta] + self.common_indexes = reduce(np.intersect1d, self.unique_indexes) + + def __len__(self) -> int: + return len(self.common_indexes)*self.max_ens_size + + def get_whole_index(self): + if self.df_ext is None: + rows = [(self.common_indexes[idx // self.max_ens_size],-1) for idx in range(len(self))] + else: + rows = [(self.common_indexes[idx // self.max_ens_size],idx % self.max_ens_size) for idx in range(len(self))] + return pd.DataFrame(rows,columns=["tstamp","member"]) + + + def __getitem__(self, idx): + if idx >= len(self): + raise IndexError(f"Index {idx} is out of range, dataset has length {len(self)}") + + tstamp = self.common_indexes[idx // self.max_ens_size] + ens_num = idx % self.max_ens_size + + x = self.df_scaled[:tstamp][-self.in_size:].astype("float32") + + if self.df_ext is not None: + for meta in self.ext_meta: + if meta.ensemble_members == 1: + #ext_values = self.df_ext_forecasts.loc[tstamp,meta_data.sensor_name] + ext_values = self.df_ext[self.df_ext["sensor_name"]== meta.sensor_name].loc[tstamp].iloc[2:].values + else: + #TODO hier crasht es wenn für einen zeitpunkt entweder ensemble oder hauptläufe gibt + ext_values = self.df_ext[(self.df_ext["sensor_name"]== meta.sensor_name) & (self.df_ext["member"]== ens_num)].loc[tstamp].iloc[2:].values + + + col_idx = x.columns.get_loc(meta.sensor_name) + x.iloc[-48:,col_idx] = ext_values.astype("float32") + + #y = self.y[idx+self.in_size:idx+self.in_size+self.out_size] + return x.values, self.y class OracleWaVoConnection: """Class for writing/reading from the Wavo Database.""" @@ -65,43 +121,85 @@ class OracleWaVoConnection: """ # self.cur_model_folder = model_folder - model = hp.load_model(gauge_config["model_folder"]) model.eval() - created = datetime.now() - for end_time in tqdm(self.times): - # load sensor sensor input data - df_base_input = self.load_input_db( - in_size=model.in_size, - end_time=end_time, - gauge_config=gauge_config, - created=created, - ) - if df_base_input is None: - continue - - # predict - # Get all external forecasts for this gauge - #stmt = select(InputForecastsMeta).where(InputForecastsMeta.sensor_name.in_(bindparam("external_fcst"))) - #params = {"external_fcst" : gauge_config["external_fcst"]} - #with Session(self.engine) as session: - # x = session.scalars(statement=stmt, params=params).fetchall() - # sensor_names = [el.sensor_name for el in x] - #params2 = {"ext_forecasts" : sensor_names, "tstamp" : end_time} - stmst2 = select(InputForecasts).where(InputForecasts.sensor_name.in_(bindparam("ext_forecasts")),InputForecasts.tstamp == (bindparam("tstamp"))) - params2 = 
{"ext_forecasts" : gauge_config["external_fcst"], "tstamp" : end_time} - df_temp = pd.read_sql(sql=stmst2,con=self.engine,index_col="tstamp",params=params2) + created = datetime.now(timezone) + #get the metadata for the external forecasts (to access them later from the db) + input_meta_data = self._ext_fcst_metadata(gauge_config) - for member in self.members: - self.handle_member( - gauge_config, - model, - df_base_input, - member, - end_time, - created, + if len(self.times)> 1: + df_base_input = self.load_input_db(model.in_size, self.times[-1], gauge_config,created) + + #load external forecasts + if self.main_config.get("fake_external"): + df_temp = None + else: + stmst = select(InputForecasts).where(InputForecasts.sensor_name.in_(gauge_config["external_fcst"]), + between(InputForecasts.tstamp, self.times[0], self.times[-1]), + InputForecasts.member.in_(self.members)) + + df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp") + + dataset = self._make_dataset(gauge_config,df_base_input, model,input_meta_data,df_temp) + df_pred = pred_multiple_db(model, dataset) + df_pred["created"] = created + df_pred["sensor_name"] = gauge_config["gauge"] + df_pred["model"] = gauge_config["model_folder"].name + + stmt =select(PegelForecasts.tstamp,PegelForecasts.member).where( + PegelForecasts.model==gauge_config["model_folder"].name, + PegelForecasts.sensor_name==gauge_config["gauge"], + PegelForecasts.tstamp.in_(df_pred["tstamp"]), + #between(PegelForecasts.tstamp, df_pred["tstamp"].min(), df_pred["tstamp"].max()), ) + + df_already_inserted = pd.read_sql(sql=stmt,con=self.engine) + merged = df_pred.merge(df_already_inserted[['tstamp', 'member']],on=['tstamp', 'member'], how='left',indicator=True) + df_pred_create = merged[merged['_merge'] == 'left_only'].drop('_merge', axis=1) + df_pred_update = merged[merged['_merge'] == 'both'].drop('_merge', axis=1) + + with Session(self.engine) as session: + forecast_objs = df_pred_create.apply(lambda row: PegelForecasts(**row.to_dict()), axis=1).tolist() + session.add_all(forecast_objs) + session.execute(update(PegelForecasts),df_pred_update.to_dict(orient='records')) + session.commit() + + else: + #TODO this is weird, cause it's just one at the moment + for end_time in self.times: + # load sensor sensor input data + df_base_input = self.load_input_db( + in_size=model.in_size, + end_time=end_time, + gauge_config=gauge_config, + created=created, + ) + if df_base_input is None: + continue + + # load external forecasts + stmst = select(InputForecasts).where(InputForecasts.sensor_name.in_(bindparam("ext_forecasts")),InputForecasts.tstamp == (bindparam("tstamp"))) + params = {"ext_forecasts" : gauge_config["external_fcst"], "tstamp" : end_time} + df_temp = pd.read_sql(sql=stmst,con=self.engine,index_col="tstamp",params=params) + + if len(gauge_config["external_fcst"]) != len(input_meta_data): + logging.error("Not all external can be found in InputForecastsMeta, only %s",[x.sensor_name for x in input_meta_data],extra={"gauge" : gauge_config["gauge"]}) + return + + for member in self.members: + self.handle_member( + gauge_config, + model, + df_base_input, + member, + end_time, + input_meta_data, + df_temp, + created + ) + + def handle_member( @@ -111,6 +209,8 @@ class OracleWaVoConnection: df_base_input: pd.DataFrame, member: int, end_time, + input_meta_data, + df_temp, created: pd.Timestamp = None, ) -> None: """ @@ -123,26 +223,51 @@ class OracleWaVoConnection: df_base_input (DataFrame): The base input DataFrame with measure values, still need the actual 
precipitation forecast. member (int): The member identifier. end_time (pd.Timestamp): The timestamp of the forecast. + input_meta_data (list): List of external forecast metadata. + df_temp (DataFrame): DataFrame with external forecasts. created (pd.Timestamp): The timestamp of the start of the forecast creation. Returns: None """ - if member == -1: - df_input = df_base_input.fillna(0) + try: + if member == -1: + df_input = df_base_input.fillna(0) + else: + # replace fake forecast with external forecast + df_input = df_base_input.copy() + for input_meta in input_meta_data: + if input_meta.ensemble_members == 1: + temp_member = 0 + elif input_meta.ensemble_members == 21: + temp_member = member + else: + temp_member = member #TODO was hier tun? + + fcst_data = df_temp[(df_temp["member"]== temp_member) & (df_temp["sensor_name"] == input_meta.sensor_name)] + if len(fcst_data) == 0: + raise LookupError(f"External forecast {input_meta.sensor_name} not found for member {temp_member}") + + # Replace missing values with forecasts. + # It is guaranteed that the only missing values are in the "future" (after end_time) + assert df_input[:-48].isna().sum().sum() == 0 + nan_indices = np.where(df_input[input_meta.sensor_name].isna())[0] - (144 - 48) + nan_indices = nan_indices[nan_indices >= 0] + nan_indices2 = nan_indices + (144 - 48) + col_idx = df_input.columns.get_loc(input_meta.sensor_name) + df_input.iloc[nan_indices2, col_idx] = fcst_data.values[0,2:].astype("float32")[nan_indices] + + except LookupError as e: + logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]}) + y = torch.Tensor(48).fill_(np.nan) else: + y = pred_single_db(model, df_input) + finally: + self.insert_forecast( + y,gauge_config,end_time, member, created + ) - # replace fake forecast with external forecast - df_input = self.merge_synth_fcst(df_base_input, member, end_time, gauge_config,created) - if df_input is None: - return - - y = pred_single_db(model, df_input) - self.insert_forecast( - y, gauge_config["gauge"], gauge_config["model_folder"].name, end_time, member, created - ) - def load_input_db(self, in_size, end_time, gauge_config,created) -> pd.DataFrame: """ Loads input data from the database based on the current configuration. 
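# Toy illustration of the index arithmetic handle_member() uses to splice an external forecast into
# the input window: the window is `in_size` rows long, only its last `horizon` rows may be NaN, and
# each NaN position maps onto the matching position of a 48-value forecast row. Sizes, the column
# name and the forecast values here are illustrative; in the real code they come from the database.
import numpy as np
import pandas as pd

in_size, horizon = 144, 48
col = "114547,Precip,h.Cmd"

df_input = pd.DataFrame({col: np.random.rand(in_size)})
df_input.iloc[-horizon:, 0] = np.nan                         # the "future" part is still unknown
ext_fcst = np.linspace(0.0, 2.0, horizon).astype("float32")  # stands in for one INPUT_FORECASTS row (h1..h48)

nan_idx = np.where(df_input[col].isna())[0] - (in_size - horizon)  # positions 0..47 inside the forecast
nan_idx = nan_idx[nan_idx >= 0]
df_input.iloc[nan_idx + (in_size - horizon), df_input.columns.get_loc(col)] = ext_fcst[nan_idx]

assert df_input[col].isna().sum() == 0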
@@ -159,7 +284,14 @@ class OracleWaVoConnection: columns=gauge_config["columns"] external_fcst=gauge_config["external_fcst"] - start_time = end_time - pd.Timedelta(in_size - 1, "hours") + + if len(self.times) > 1: + first_end_time = self.times[0] + start_time = first_end_time - pd.Timedelta(in_size - 1, "hours") + if self.main_config.get("fake_external"): + end_time = end_time + pd.Timedelta(48, "hours") + else: + start_time = end_time - pd.Timedelta(in_size - 1, "hours") stmt = select(SensorData).where( between(SensorData.tstamp, bindparam("start_time"), bindparam("end_time")), @@ -184,7 +316,7 @@ class OracleWaVoConnection: df_rest = df_main[end_time + pd.Timedelta("1h") :] df_rest.index = df_rest.index - pd.Timedelta("48h") - if len(df_input) != in_size: + if len(self.times)== 1 and len(df_input) != in_size: raise MissingTimestampsError( f"Some timestamps are completely missing, found {len(df_input)}/{in_size} hours from {start_time} to {end_time}" ) @@ -217,6 +349,9 @@ class OracleWaVoConnection: # This is a little messy because we don't want to interpolate those values and need to be careful about missing timestamps df_input.loc[df_rest.index, col] = df_rest[col] + if self.main_config.get("fake_external"): + df_input = df_input.iloc[:-48] + except MissingTimestampsError as e: logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]}) df_input = None @@ -224,31 +359,24 @@ class OracleWaVoConnection: logging.error(e.args[0],extra={"gauge":gauge_config["gauge"]}) df_input = None except KeyError as e: + #logging.error( + # "Some columns are missing in the database, found %s", + # df_main.columns,extra={"gauge":gauge_config["gauge"]}) logging.error( "Some columns are missing in the database, found %s", - df_main.columns,extra={"gauge":gauge_config["gauge"]}) + df_main["sensor_name"].unique(),extra={"gauge":gauge_config["gauge"]}) logging.error(e.args[0]) df_input = None - #except LookupError as e: - # if df_input is not None: - # logging.error( - # "No data for the chosen timeperiod up to %s and columns in the database.", - # end_time, - # extra={"gauge":gauge_config["gauge"]} - # ) - # df_input = None - # logging.error(e.args[0]) - if df_input is None: logging.error("Input sensordata could not be loaded, inserting empty forecasts",extra={"gauge":gauge_config["gauge"]}) + #TODO logging here is weird/wrong with multiple predictions. y = torch.Tensor(48).fill_(np.nan) for member in self.members: try: self.insert_forecast( y, - gauge_config["gauge"], - gauge_config["model_folder"].name, + gauge_config, end_time, member, created, @@ -260,57 +388,6 @@ class OracleWaVoConnection: return df_input - def merge_synth_fcst( - self, - df_base_input: pd.DataFrame, - member: int, - end_time: pd.Timestamp, - gauge_config: dict, - created: pd.Timestamp = None, - ) -> pd.DataFrame: - """ - Merges external forecasts into the base input dataframe by replacing the old wrong values. - - Args: - df_base_input (pd.DataFrame): The base input dataframe. - member (int): The member identifier. - end_time (pd.Timestamp): The timestamp of the forecast. - gauge_config (dict): Settings for the gauge, especially the external forecasts and the gauge name. - created (pd.Timestamp): The timestamp of the start of the forecast creation. - - Returns: - pd.DataFrame: The merged dataframe with external forecasts. - """ - #TODO: besprechen ob wir einfach alle spalten durch vorhersagen ergänzen falls vorhanden? 
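# Minimal sketch of the completeness check behind MissingTimestampsError: for a single forecast
# time, load_input_db() expects exactly `in_size` hourly rows between start_time and end_time.
# The timestamps and the two "missing" hours are made up for this demo.
import pandas as pd

in_size = 144
end_time = pd.Timestamp("2024-09-13 09:00")
start_time = end_time - pd.Timedelta(in_size - 1, "hours")

expected = pd.date_range(start_time, end_time, freq="h")
loaded = expected.delete([10, 11])   # pretend two hours are absent from SENSOR_DATA

if len(loaded) != in_size:
    missing = expected.difference(loaded)
    print(f"Some timestamps are completely missing, found {len(loaded)}/{in_size} hours, e.g. {missing[0]}")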
- df_input = df_base_input.copy() - - for col in gauge_config["external_fcst"]: - try: - ext_fcst = self.get_ext_forecast(col, member, end_time) - except AttributeError: - logging.error( - "External forecast %s time %s ensemble %s is missing", - col, - end_time, - member, - extra={"gauge":gauge_config["gauge"]} - ) - # if some external forecasts are missing, insert NaNs and skip prediction - y = torch.Tensor(48).fill_(np.nan) - self.insert_forecast(y, gauge_config["gauge"], gauge_config["model_folder"].name, end_time, member, created) - return None - else: - # Replace missing values with forecasts. - # It is guaranteed that the only missing values are in the "future" (after end_time) - assert df_input[:-48].isna().sum().sum() == 0 - nan_indices = np.where(df_input[col].isna())[0] - (144 - 48) - nan_indices = nan_indices[nan_indices >= 0] - nan_indices2 = nan_indices + (144 - 48) - df_input.iloc[nan_indices2, df_input.columns.get_loc(col)] = np.array( - ext_fcst - )[nan_indices] - - return df_input def add_zrxp_data(self, zrxp_folders: List[Path]) -> None: """Adds zrxp data to the database @@ -332,29 +409,18 @@ class OracleWaVoConnection: zrxp_file, skiprows=3, header=None, sep=" ", parse_dates=[0] ) - # zrxp_time = df_zrxp.iloc[0, 0] - # member = int(df_zrxp.iloc[0, 2]) - # try: - # # self.insert_external_forecast(zrxp_time, fcst_name, df_zrxp[3], member) - # self.insert_external_forecast( - # zrxp_time, vhs_gebiet, df_zrxp[3], member - # ) try: if len(df_zrxp.columns) == 4: zrxp_time = df_zrxp.iloc[0, 0] member = int(df_zrxp.iloc[0, 2]) - self.insert_external_forecast( - zrxp_time, vhs_gebiet, df_zrxp[3], member - ) + forecast = df_zrxp[3] + else: zrxp_time = df_zrxp.iloc[0, 0] - pd.Timedelta(hours=1) member = 0 - # self.insert_external_forecast(zrxp_time, fcst_name, df_zrxp[3], member) - self.insert_external_forecast( - zrxp_time, vhs_gebiet, df_zrxp[1], member - ) + forecast = df_zrxp[1] + self.insert_external_forecast(zrxp_time, vhs_gebiet, forecast, member) except oracledb.IntegrityError as e: - if e.args[0].code == 2291: logging.warning( "%s Does the sensor_name %s exist in MODELL_SENSOR.VHS_GEBIET?", @@ -381,10 +447,6 @@ class OracleWaVoConnection: Returns: List(int): The external forecast data. """ - # TODO check new column in MODELL_SENSOR - - #TODO this would be more efficient outside the member loop - #TODO error handling stmt = select(InputForecastsMeta).where(InputForecastsMeta.sensor_name == (bindparam("sensor_name"))) params = {"sensor_name" : sensor_name} with Session(self.engine) as session: @@ -393,10 +455,6 @@ class OracleWaVoConnection: if input_meta is None: raise AttributeError(f"Forecast {sensor_name} not found in Table InputForecastsMeta") - #sensor = self.get_sensor_name(vhs_gebiet) # maybe check for none? - #if sensor is None: - # raise AttributeError(f"Sensor {vhs_gebiet} not found in database") - stmt2 = select(InputForecasts).where( InputForecasts.tstamp == bindparam("tstamp"), InputForecasts.sensor_name == bindparam("sensor_name"), @@ -429,7 +487,6 @@ class OracleWaVoConnection: Returns: None """ - sensor = self.get_sensor_name(vhs_gebiet) if sensor is None: return @@ -466,8 +523,7 @@ class OracleWaVoConnection: def insert_forecast( self, forecast: torch.Tensor, - sensor_name: str, - model_name: str, + gauge_config: dict, end_time: pd.Timestamp, member: int, created: pd.Timestamp = None, @@ -478,16 +534,22 @@ class OracleWaVoConnection: Args: forecast (Tensor): The forecast data to be inserted. + gauge_config (dict): The configuration for the gauge. 
+ end_time (pd.Timestamp): The timestamp of the forecast. member (int): The member identifier. - + created (pd.Timestamp, optional): The timestamp of the forecast creation. Returns: None """ + sensor_name = gauge_config["gauge"] + model_name = gauge_config["model_folder"].name + # Turn tensor into dict to if forecast.isnan().any(): fcst_values = {f"h{i}": None for i in range(1, 49)} else: fcst_values = {f"h{i}": forecast[i - 1].item() for i in range(1, 49)} + self._export_zrxp(forecast, gauge_config, end_time, member, sensor_name, model_name, created) stmt = select(PegelForecasts).where( PegelForecasts.tstamp == bindparam("tstamp"), @@ -534,7 +596,68 @@ class OracleWaVoConnection: #try: session.commit() - + + def _export_zrxp(self, forecast, gauge_config, end_time, member, sensor_name, model_name, created): + if self.main_config.get("export_zrxp") and member == 0: + + try: + df_models = pd.read_sql(sql=select(Modell.modellname, Modell.modelldatei), con=self.engine) + df_models["modelldatei"] = df_models["modelldatei"].apply(Path) + df_models["modelldatei"] = df_models["modelldatei"].apply(lambda x: x.resolve()) + real_model_name = df_models[df_models["modelldatei"]== gauge_config["model_folder"].resolve()]["modellname"].values[0] + except IndexError: + logging.error("Model %s not found in database, skipping zrxp export",gauge_config["model_folder"]) + else: + target_file =self.main_config["zrxp_out_folder"] / f"{end_time.strftime('%Y%m%d%H')}_{sensor_name.replace(',','_')}_{model_name}.zrx" + #2023 11 19 03 + df_zrxp = pd.DataFrame(forecast,columns=["value"]) + df_zrxp["timestamp"] = end_time + df_zrxp["forecast"] = pd.date_range(start=end_time,periods=49,freq="1h")[1:] + df_zrxp["member"] = member + df_zrxp = df_zrxp[['timestamp','forecast','member','value']] + + + with open(target_file , 'w', encoding="utf-8") as file: + file.write('#REXCHANGEWISKI.' 
+ real_model_name + '.W.KNN|*|\n') + file.write('#RINVAL-777|*|\n') + file.write('#LAYOUT(timestamp,forecast, member,value)|*|\n') + + df_zrxp.to_csv(path_or_buf = target_file, + header = False, + index=False, + mode='a', + sep = ' ', + date_format = '%Y%m%d%H%M') + + + def _ext_fcst_metadata(self, gauge_config): + # Get all external forecasts for this gauge + stmt = select(InputForecastsMeta).where(InputForecastsMeta.sensor_name.in_(bindparam("external_fcst"))) + with Session(self.engine) as session: + input_meta_data = session.scalars(statement=stmt, params={"external_fcst" : gauge_config["external_fcst"]}).fetchall() + + return input_meta_data + + + + def _make_dataset(self,gauge_config,df_input, model,input_meta_data,df_ext): + with warnings.catch_warnings(): + warnings.filterwarnings(action="ignore", category=UserWarning) + if model.differencing == 1: + df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0) + #TODO time embedding stuff + df_scaled = pd.DataFrame(model.scaler.transform(df_input.values),index=df_input.index,columns=df_input.columns) + trans_cols = [f"h{i}" for i in range(1,49)] + for col in gauge_config["external_fcst"]: + idx = gauge_config["columns"].index(col) + scale_curry = partial(scale_standard,mean=model.scaler.mean_[idx],std=model.scaler.scale_[idx]) #TODO maybe people want to use different scalers + + df_ext.loc[df_ext["sensor_name"]==col,trans_cols] = df_ext.loc[df_ext["sensor_name"]==col,trans_cols].apply(scale_curry) + + ds = DBDataSet(df_scaled,model.in_size,input_meta_data,df_ext) + return ds + + def maybe_update_tables(self) -> None: """ Updates the database tables if necessary based on the configuration settings. @@ -551,90 +674,6 @@ class OracleWaVoConnection: if self.main_config.get("load_zrxp"): self.add_zrxp_data(self.main_config["zrxp_folder"]) - def maybe_create_tables(self) -> None: - """ - Checks if the required tables exist in the database and creates them if they don't exist. - """ - print( - "THIS METHOD IS DEPRECATED, the LFU created tables, not sure how to map their constraints to the oracle_db library" - ) - return - - with self.con.cursor() as cursor: - table_list = [ - row[0] for row in cursor.execute("SELECT table_name FROM user_tables") - ] - - if "SENSOR_DATA" not in table_list: - logging.warning( - "Warning: Table SENSOR_DATA not in database, creating new empty table" - ) - self.create_sensor_table() - - if "PEGEL_FORECASTS" not in table_list: - logging.warning( - "Warning: Table PEGEL_FORECASTS not in database, creating new empty table" - ) - self.create_forecast_table() - - if "INPUT_FORECASTS" not in table_list: - logging.warning( - "Warning: Table INPUT_FORECASTS not in database, creating new empty table" - ) - self.create_external_forecast_table() - - def create_sensor_table(self) -> None: - """ - Creates the table SENSOR_DATA with columns tstamp, sensor_name, sensor_value. - """ - with self.con.cursor() as cursor: - query = """CREATE TABLE SENSOR_DATA ( - tstamp TIMESTAMP NOT NULL, - sensor_name varchar(256) NOT NULL, - sensor_value FLOAT, - PRIMARY KEY (tstamp, sensor_name))""" - cursor.execute(query) - self.con.commit() - - def create_forecast_table(self) -> None: - """ - Creates the table PEGEL_FORECASTS with columns tstamp, gauge, model, member, created and a value column for each of the 48 forecast hours. 
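# Sketch of the ZRXP layout written by _export_zrxp() (and by the removed BSH extractor): three '#'
# header lines followed by space-separated rows of timestamp, forecast, member, value with
# YYYYmmddHHMM dates. The exchange name "example_model", the file name and the values are made up.
import pandas as pd

end_time = pd.Timestamp("2024-09-13 09:00")
df_zrxp = pd.DataFrame({
    "timestamp": end_time,
    "forecast": pd.date_range(start=end_time, periods=49, freq="1h")[1:],
    "member": 0,
    "value": [1.0] * 48,
})

with open("example.zrx", "w", encoding="utf-8") as fh:
    fh.write("#REXCHANGEWISKI.example_model.W.KNN|*|\n")
    fh.write("#RINVAL-777|*|\n")
    fh.write("#LAYOUT(timestamp,forecast,member,value)|*|\n")

df_zrxp.to_csv("example.zrx", header=False, index=False, mode="a", sep=" ",
               date_format="%Y%m%d%H%M")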
- """ - with self.con.cursor() as cursor: - query = ( - """CREATE TABLE PEGEL_FORECASTS ( - tstamp TIMESTAMP NOT NULL, - sensor_name varchar(256) NOT NULL, - model varchar(256) NOT NULL, - member INTEGER NOT NULL, - created TIMESTAMP NOT NULL, - """ - + "\n".join([f"h{i} FLOAT," for i in range(1, 49)]) - + """ - PRIMARY KEY (tstamp,sensor_name,model,member) - )""" - ) - cursor.execute(query) - self.con.commit() - - def create_external_forecast_table(self) -> None: - """ - Creates the table INPUT_FORECASTS with columns tstamp, sensor_name, member and a value column for each of the 48 forecast hours. - """ - with self.con.cursor() as cursor: - query = ( - """CREATE TABLE INPUT_FORECASTS ( - tstamp TIMESTAMP NOT NULL, - sensor_name varchar(255) NOT NULL, - member INTEGER NOT NULL, - """ - + "\n".join([f"h{i} FLOAT," for i in range(1, 49)]) - + """ - PRIMARY KEY (tstamp,sensor_name,member) - )""" - ) - cursor.execute(query) - self.con.commit() def get_member_list(self) -> List[int]: """Returns a list of member identifiers based on the main configuration. @@ -665,13 +704,13 @@ class OracleWaVoConnection: if "range" in self.main_config and self.main_config["range"]: return list( pd.date_range( - self.main_config["start"], self.main_config["end"], freq="H" + self.main_config["start"], self.main_config["end"], freq="h" ) ) return [pd.to_datetime(self.main_config["start"])] def get_sensor_name(self, vhs_gebiet: str) -> Sensor: - """Returns the sensor name based on the vhs_gebiet, by looking it up in the table ModellSensor and checking for uniqueness of combination and existence. + """Returns the sensor name based on the vhs_gebiet, by looking it up in the table INPUT_FORECASTS_META Args: vhs_gebiet (str): The vhs_gebiet name. @@ -680,28 +719,12 @@ class OracleWaVoConnection: Sensor: A Sensor object (a row from the table Sensor). """ + stmt = select(InputForecastsMeta).where(InputForecastsMeta.vhs_gebiet == bindparam("vhs_gebiet")) + params = {"vhs_gebiet" : vhs_gebiet} + with Session(self.engine) as session: + input_meta = session.scalar(statement=stmt, params=params) - with Session(bind=self.engine) as session: - # get all sensor names for the vhs_gebiet from the model_sensor table and check if there is only one. 
- stmt = select(ModellSensor.sensor_name).where( - ModellSensor.vhs_gebiet == bindparam("vhs_gebiet") - ) - sensor_names = set( - session.scalars(stmt, params={"vhs_gebiet": vhs_gebiet}).all() - ) - if len(sensor_names) == 0: - logging.warning("No sensor_name found for %s", vhs_gebiet) - self.not_found.append(vhs_gebiet) - - return - elif len(sensor_names) > 1: - logging.warning("Multiple sensor_names found for %s", vhs_gebiet) - return - sensor_name = list(sensor_names)[0] - sensor = session.scalar( - select(Sensor).where(Sensor.sensor_name == sensor_name) - ) - return sensor + return input_meta # @deprecated("This method is deprecated, data can be loaded directly from wiski (Not part of this package).") def add_sensor_data(self, sensor_folder, force=False) -> None: @@ -804,18 +827,39 @@ def pred_single_db(model, df_input: pd.DataFrame) -> torch.Tensor: warnings.filterwarnings(action="ignore", category=UserWarning) if model.differencing == 1: df_input["d1"] = df_input.iloc[:, model.gauge_idx].diff().fillna(0) - + #TODO time embedding stuff x = model.scaler.transform( df_input.values ) # TODO values/df column names wrong/old x = torch.tensor(data=x, device=model.device, dtype=torch.float32).unsqueeze(dim=0) - # y = ut.inv_standard(model(x), model.scaler.mean_[model.target_idx], model.scaler.scale_[model.target_idx]) y = model.predict_step((x, None), 0) return y[0] + +@torch.no_grad() +def pred_multiple_db(model, ds: Dataset) -> torch.Tensor: + """ + Predicts the output for a single input data point. + + Args: + model (LightningModule): The model to use for prediction. + df_input (pandas.DataFrame): Input data point as a DataFrame. + + Returns: + torch.Tensor: Predicted output as a tensor. + """ + + data_loader = DataLoader(ds, batch_size=256,num_workers=1) + pred = hp.get_pred(model, data_loader,None) + df_pred = ds.get_whole_index() + df_pred = pd.concat([df_pred, pd.DataFrame(pred,columns=[f"h{i}" for i in range(1,49)])], axis=1) + + return df_pred + + def get_merge_forecast_query(out_size: int = 48) -> str: """Returns a merge query for the PEGEL_FORECASTS table with the correct number of forecast hours.""" @@ -850,57 +894,7 @@ def get_merge_forecast_query(out_size: int = 48) -> str: return merge_forecast_query -class OracleDBHandler(logging.Handler): - """ - adapted from https://gist.github.com/ykessler/2662203 - """ - def __init__(self, con): - logging.Handler.__init__(self) - self.con = con - self.cursor = con.cursor() - self.engine = create_engine("oracle+oracledb://", creator=lambda: con) - self.session = Session(bind=self.engine) - - # self.query = "INSERT INTO LOG VALUES(:1, :2, :3, :4, :5, :6, :7)" - - # Create table if needed: - table_list = [ - row[0] for row in self.cursor.execute("SELECT table_name FROM user_tables") - ] - if "LOG" not in table_list: - query = """CREATE TABLE LOG( - ID int GENERATED BY DEFAULT AS IDENTITY (START WITH 1), - Created TIMESTAMP, - LogLevelName varchar(32), - Message varchar(256), - Module varchar(64), - FuncName varchar(64), - LineNo int, - Exception varchar(256) - CONSTRAINT PK_LOG PRIMARY KEY (ID) - ) ;""" - self.cursor.execute(query) - self.con.commit() - - def emit(self, record): - if record.exc_info: - record.exc_text = logging._defaultFormatter.formatException(record.exc_info) - else: - record.exc_text = "" - - log = Log( - created=pd.to_datetime(record.created,unit="s"), - loglevelname=record.levelname, - message=record.message, - module=record.module, - funcname=record.funcName, - lineno=record.lineno, - exception=record.exc_text, - 
gauge=record.gauge_id, - ) - self.session.add(log) - self.session.commit() class MissingTimestampsError(LookupError): pass diff --git a/src/utils/orm_classes.py b/src/utils/db_tools/orm_classes.py similarity index 81% rename from src/utils/orm_classes.py rename to src/utils/db_tools/orm_classes.py index a80a75f680486c6657a02544e5bf27309cef5722..eea22b84b4fbd502ee68e58e5fcd219bf12caa11 100644 --- a/src/utils/orm_classes.py +++ b/src/utils/db_tools/orm_classes.py @@ -4,7 +4,7 @@ The alternative is to use table reflection from sqlalchemy. from typing import List, Optional import datetime -from sqlalchemy import Double, ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, TIMESTAMP, VARCHAR,Identity +from sqlalchemy import Double, ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, TIMESTAMP, VARCHAR,Identity, String, UniqueConstraint, text from sqlalchemy.dialects.oracle import NUMBER from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -40,6 +40,11 @@ class InputForecastsMeta(Base): vhs_gebiet : Mapped[str] = mapped_column(VARCHAR(256), primary_key=True) ensemble_members : Mapped[int] = mapped_column(NUMBER(5, 0, False)) + def __repr__(self) -> str: + return (f"InputForecastsMeta(sensor_name='{self.sensor_name}', " + f"vhs_gebiet='{self.vhs_gebiet}', " + f"ensemble_members={self.ensemble_members})") + class Modell(Base): __tablename__ = 'modell' @@ -48,15 +53,37 @@ class Modell(Base): Index('id_modellname', 'modellname', unique=True) ) - id: Mapped[float] = mapped_column(NUMBER(10, 0, False), primary_key=True) + #id: Mapped[float] = mapped_column(NUMBER(10, 0, False), primary_key=True) + id: Mapped[float] = mapped_column( + NUMBER(10, 0, False), + Identity( + start=1, + increment=1, + minvalue=1, + maxvalue=9999999999, + cycle=False, + cache=20, + order=False + ), + primary_key=True + ) + mstnr: Mapped[Optional[str]] = mapped_column(VARCHAR(30)) modellname: Mapped[Optional[str]] = mapped_column(VARCHAR(256)) modelldatei: Mapped[Optional[str]] = mapped_column(VARCHAR(256)) aktiv: Mapped[Optional[float]] = mapped_column(NUMBER(1, 0, False)) kommentar: Mapped[Optional[str]] = mapped_column(VARCHAR(1024)) + in_size: Mapped[Optional[int]] = mapped_column( NUMBER(38,0)) + target_sensor_name: Mapped[Optional[str]] = mapped_column(VARCHAR(256)) + metric: Mapped[List['Metric']] = relationship('Metric', back_populates='model') modell_sensor: Mapped[List['ModellSensor']] = relationship('ModellSensor', back_populates='modell') + def __repr__(self) -> str: + return (f"Modell(id={self.id}, " + f"mstnr='{self.mstnr}', " + f"modellname='{self.modellname}', " + f"aktiv={self.aktiv})") class Sensor(Base): __tablename__ = 'sensor' @@ -77,6 +104,11 @@ class Sensor(Base): pegel_forecasts: Mapped[List['PegelForecasts']] = relationship('PegelForecasts', back_populates='sensor') sensor_data: Mapped[List['SensorData']] = relationship('SensorData', back_populates='sensor') + def __repr__(self) -> str: + return (f"Sensor(sensor_name='{self.sensor_name}', " + f"mstnr='{self.mstnr}', " + f"parameter_kurzname='{self.parameter_kurzname}', " + f"ts_shortname='{self.ts_shortname}')") class TmpSensorData(Base): __tablename__ = 'tmp_sensor_data' @@ -90,7 +122,6 @@ class TmpSensorData(Base): - class InputForecasts(Base): __tablename__ = 'input_forecasts' __table_args__ = ( @@ -155,6 +186,31 @@ class InputForecasts(Base): sensor: Mapped['Sensor'] = relationship('Sensor', back_populates='input_forecasts') + +class Metric(Base): + __tablename__ = 'metric' + __table_args__ = ( + 
@@ -155,6 +186,31 @@ class InputForecasts(Base):
     sensor: Mapped['Sensor'] = relationship('Sensor', back_populates='input_forecasts')
 
+
+class Metric(Base):
+    __tablename__ = 'metric'
+    __table_args__ = (
+        ForeignKeyConstraint(['model_id'], ['modell.id'], name='metric_to_model'),
+        PrimaryKeyConstraint('id', name='sys_c008840'),
+        Index('metric_unique', 'model_id', 'metric_name', 'timestep', 'tag', unique=True)
+    )
+    #TODO improve the constraints (start time + end time should be unique)
+    #TODO the server_default will probably not work at the lfu
+    #id: Mapped[float] = mapped_column(NUMBER(19, 0, False), primary_key=True, server_default=text('"C##MSPILS"."METRICS_SEQ"."NEXTVAL"'))
+    id: Mapped[float] = mapped_column(NUMBER(19, 0, False),
+                                      Identity(always=True, on_null=False, start=1, increment=1, minvalue=1, maxvalue=9999999999999999999999999999, cycle=False, cache=20, order=False),
+                                      primary_key=True)
+    model_id: Mapped[float] = mapped_column(NUMBER(10, 0, False))
+    metric_name: Mapped[str] = mapped_column(VARCHAR(100))
+    timestep: Mapped[int] = mapped_column(Integer)
+    start_time: Mapped[datetime.datetime] = mapped_column(TIMESTAMP)
+    end_time: Mapped[datetime.datetime] = mapped_column(TIMESTAMP)
+    value: Mapped[float] = mapped_column(Double)
+    tag: Mapped[Optional[str]] = mapped_column(VARCHAR(100))
+
+    model: Mapped['Modell'] = relationship('Modell', back_populates='metric')
+
+
 class ModellSensor(Base):
     __tablename__ = 'modell_sensor'
     __table_args__ = (
diff --git a/src/utils/helpers.py b/src/utils/helpers.py
index 7848312a5a75395c3d0c6011a107c0354a91b75b..56bf836d999dab1fda8d0ed1c138982523dfc831 100644
--- a/src/utils/helpers.py
+++ b/src/utils/helpers.py
@@ -85,10 +85,11 @@ def get_pred(
     if move_model:
         model.cuda()
 
-    pred = np.concatenate(pred)
+    y_pred = np.concatenate(pred)
 
-    y_pred = np.concatenate([np.expand_dims(y_true, axis=1), pred], axis=1)
-    y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
+    if y_true is not None:
+        y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1)
+        y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1]))
 
     return y_pred
 
@@ -121,16 +122,25 @@
             optimizer="adam",
         )
     except TypeError:
-        model = WaVoLightningModule.load_from_checkpoint(
-            checkpoint_path=next((model_dir / "checkpoints").iterdir()),
-            map_location=map_location,
-            scaler=config["scaler"],
-            optimizer="adam",
-            differencing=0,
-            gauge_idx=-1,
-            embed_time=False
-        )
-
+        try:
+            model = WaVoLightningModule.load_from_checkpoint(
+                checkpoint_path=next((model_dir / "checkpoints").iterdir()),
+                map_location=map_location,
+                scaler=config["scaler"],
+                optimizer="adam",
+                gauge_idx=-1,
+                embed_time=False
+            )
+        except TypeError:
+            model = WaVoLightningModule.load_from_checkpoint(
+                checkpoint_path=next((model_dir / "checkpoints").iterdir()),
+                map_location=map_location,
+                scaler=config["scaler"],
+                optimizer="adam",
+                differencing=0,
+                gauge_idx=-1,
+                embed_time=False
+            )
 
     return model
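Editor's note (not part of the patch): the load_model hunk above nests a second try/except so that checkpoints written by older model versions can still be loaded; their saved hyperparameters lack the newer constructor arguments, so each TypeError is answered by retrying with more defaults filled in. The sketch below isolates that retry idea; load_with_fallbacks and old_checkpoint_loader are invented names, while the real code simply retries WaVoLightningModule.load_from_checkpoint up to three times.

def load_with_fallbacks(load_fn, base_kwargs, fallback_layers):
    """Call load_fn, adding one layer of default kwargs per TypeError."""
    extra = {}
    for layer in [{}, *fallback_layers]:
        extra.update(layer)
        try:
            return load_fn(**base_kwargs, **extra)
        except TypeError:
            continue
    raise TypeError("no kwarg combination matched the checkpoint signature")

def old_checkpoint_loader(checkpoint_path, scaler, optimizer, gauge_idx, embed_time, differencing):
    # Pretend this is an older model class whose constructor insists on `differencing`.
    return {"path": checkpoint_path, "differencing": differencing}

model = load_with_fallbacks(
    old_checkpoint_loader,
    {"checkpoint_path": "ckpt.pt", "scaler": None, "optimizer": "adam"},
    [{"gauge_idx": -1, "embed_time": False}, {"differencing": 0}],
)
print(model)  # {'path': 'ckpt.pt', 'differencing': 0}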
diff --git a/src/utils/utility.py b/src/utils/utility.py
index 954e1253f69577c17d0cd8b1e571e6140f752eff..46341c9aebbf1983497dc5576f3b8d98f7f3ffa1 100644
--- a/src/utils/utility.py
+++ b/src/utils/utility.py
@@ -91,6 +91,24 @@ def inv_minmax(x, xmin: float, xmax: float):
     return x * (xmax-xmin) + xmin
 
 
+def scale_standard(x, mean: float = None, std: float = None):
+    """Standard scaling, also unpacks torch tensors if needed
+
+    Args:
+        x (np.array, torch.Tensor or similar): array to scale
+        mean (float): mean value of a (training) dataset
+        std (float): standard deviation value of a (training) dataset
+
+    Returns:
+        (np.array, torch.Tensor or similar): standard scaled values
+    """
+    if isinstance(x, torch.Tensor) and isinstance(mean, torch.Tensor) and isinstance(std, torch.Tensor):
+        return (x - mean) / std
+    if isinstance(mean, torch.Tensor):
+        mean = mean.item()
+    if isinstance(std, torch.Tensor):
+        std = std.item()
+    return (x - mean) / std
+
+
 def inv_standard(x, mean: float, std: float):
     """Inverse standard scaling, also unpacks torch tensors if needed