diff --git a/src/scripts/extract_metrics.py b/src/scripts/extract_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..5f170d4042f18d32d1782cd3392b936d0b315aba --- /dev/null +++ b/src/scripts/extract_metrics.py @@ -0,0 +1,54 @@ +import argparse +from ast import arg +from tbparse import SummaryReader +import os +from pathlib import Path + + + + +def extract_summaries_to_csv(root_dir,metrics=None,force=False): + """Takes a root directory and extracts metrics from the tensorboard file into a metrics.csv file. + Tensorboard files look something like events.out.tfevents.1701686993.kiel.2222904.0 + + default metrics are "hp/test_p10","hp/test_p20","hp/test_rmse","hp/test_p10_flood","hp/test_p20_flood","hp/test_rmse_flood" + + + Args: + root_dir (str/path): Directory with folders containing tensorboard files and models. Those are expected to beging with version + metrics (list[str], optional): List of metrics to extract. Defaults to None. + """ + + if metrics is None: + metrics = ["hp/test_p10","hp/test_p20","hp/test_rmse","hp/test_p10_flood","hp/test_p20_flood","hp/test_rmse_flood"] + + for root, _, files in os.walk(root_dir): + if root.split("/")[-1].startswith("version"): + if "metrics.csv" in files and not force: + print(f"Skipping {root} as metrics.csv already exists. Use --force to overwrite") + continue + #log_dir = "../../../data-project/KIWaVo/models/lfu/willenscharen6/lightning_logs/version_0/" + reader_hp = SummaryReader(root,pivot=True) + df =reader_hp.scalars + df = df.drop(["step","epoch"],axis=1)[1:49] + df = df[metrics] + df.columns = df.columns.str.slice(3) + df.to_csv(root + "/metrics.csv") + print(f"Extracted metrics for {root}") + + +def main(): + parser = argparse.ArgumentParser("Extract metrics from tensorboard files to csv. Will check folders that start with 'version'") + parser.add_argument("root_dir", type=str, help="Root directory with tensorboard files") + parser.add_argument("--metrics", type=str, help="""Metrics to extract. + Default: "hp/test_p10","hp/test_p20","hp/test_rmse","hp/test_p10_flood","hp/test_p20_flood","hp/test_rmse_flood" + It is probably easiert to modify the source code if you want to change the metrics. + """, nargs="+") + parser.add_argument("--force", action="store_true", help="Force overwrite of existing metrics.csv files") + + args = parser.parse_args() + print(args) + extract_summaries_to_csv(args.root_dir,args.metrics,args.force) + +if __name__ == "__main__": + main() \ No newline at end of file