From 7ddb94af6ccc6f3e6e7067cde135f25fe168a756 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 25 Sep 2024 09:12:39 +0200 Subject: [PATCH] Update. --- main.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 87e73ad..230453f 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,11 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) parser.add_argument("--proba_prompt_noise", type=float, default=0.05) -parser.add_argument("--proba_hints", type=float, default=0.1) +parser.add_argument("--proba_hints", type=float, default=0.25) + +# ---------------------------------- + +parser.add_argument("--test", type=str, default=None) ###################################################################### @@ -971,6 +975,32 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") # model.test_accuracy = 0 #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +if args.test == "params": + for model in models: + filename = f"ae_{model.id:03d}_naive.pth" + + d = torch.load( + os.path.join(args.result_dir, filename), + map_location="cpu", + weights_only=False, + ) + + early_model = new_model() + early_model.load_state_dict(d["state_dict"]) + early_model.optimizer.load_state_dict(d["optimizer_state_dict"]) + early_model.test_accuracy = d["test_accuracy"] + early_model.nb_epochs = d["nb_epochs"] + + log_string(f"successfully loaded {filename}") + + print(f"-- {model.id} -------------------------------") + for ep, p in zip(early_model.parameters(), model.parameters()): + print(f"mean {model.id} {ep.mean()} {p.mean()} std {ep.std()} {p.std()}") + + exit(0) + +###################################################################### + for n_epoch in range(current_epoch, args.nb_epochs): start_time = time.perf_counter() -- 2.39.5