From f8c0c61af2db75655cdcefde1c949d669f5f96fb Mon Sep 17 00:00:00 2001 From: Michel Spils <msp@informatik.uni-kiel.de> Date: Thu, 12 Dec 2024 14:47:26 +0000 Subject: [PATCH] minor bugs --- notebooks | 1 + src/scripts/extract_metrics.py | 17 ++++++++++------- src/utils/helpers.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) create mode 160000 notebooks diff --git a/notebooks b/notebooks new file mode 160000 index 0000000..abdb069 --- /dev/null +++ b/notebooks @@ -0,0 +1 @@ +Subproject commit abdb06965783302b8d5dd5d7f8bdb8e3cc9d639c diff --git a/src/scripts/extract_metrics.py b/src/scripts/extract_metrics.py index 5f170d4..fb77b27 100644 --- a/src/scripts/extract_metrics.py +++ b/src/scripts/extract_metrics.py @@ -28,13 +28,16 @@ def extract_summaries_to_csv(root_dir,metrics=None,force=False): print(f"Skipping {root} as metrics.csv already exists. Use --force to overwrite") continue #log_dir = "../../../data-project/KIWaVo/models/lfu/willenscharen6/lightning_logs/version_0/" - reader_hp = SummaryReader(root,pivot=True) - df =reader_hp.scalars - df = df.drop(["step","epoch"],axis=1)[1:49] - df = df[metrics] - df.columns = df.columns.str.slice(3) - df.to_csv(root + "/metrics.csv") - print(f"Extracted metrics for {root}") + try : + reader_hp = SummaryReader(root,pivot=True) + df =reader_hp.scalars + df = df.drop(["step","epoch"],axis=1)[1:49] + df = df[metrics] + df.columns = df.columns.str.slice(3) + df.to_csv(root + "/metrics.csv") + print(f"Extracted metrics for {root}") + except Exception as e: + print(f"Error in {root}") def main(): diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 56bf836..f87f63c 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -88,7 +88,7 @@ def get_pred( y_pred = np.concatenate(pred) if y_true is not None: - y_pred = np.concatenate([np.expand_dims(y_true, axis=1), pred], axis=1) + y_pred = np.concatenate([np.expand_dims(y_true, axis=1), y_pred], axis=1) # Replaced pred with y_pred, why was it pred? y_pred = pd.DataFrame(y_pred, index=y_true.index, columns=range(y_pred.shape[1])) return y_pred -- GitLab