Update: raise the --proba_hints default to 0.25 and add a --test option with a "params" mode that compares saved and current model parameters.
author	François Fleuret <francois@fleuret.org>
Wed, 25 Sep 2024 07:12:39 +0000 (09:12 +0200)
committer	François Fleuret <francois@fleuret.org>
Wed, 25 Sep 2024 07:12:39 +0000 (09:12 +0200)
main.py

diff --git a/main.py b/main.py
index 87e73ad..230453f 100755
--- a/main.py
+++ b/main.py
@@ -101,7 +101,11 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
 
-parser.add_argument("--proba_hints", type=float, default=0.1)
+parser.add_argument("--proba_hints", type=float, default=0.25)
+
+# ----------------------------------
+
+parser.add_argument("--test", type=str, default=None)
 
 ######################################################################
 
@@ -971,6 +975,32 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 # model.test_accuracy = 0
 #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
+if args.test == "params":
+    for model in models:
+        filename = f"ae_{model.id:03d}_naive.pth"
+
+        d = torch.load(
+            os.path.join(args.result_dir, filename),
+            map_location="cpu",
+            weights_only=False,
+        )
+
+        early_model = new_model()
+        early_model.load_state_dict(d["state_dict"])
+        early_model.optimizer.load_state_dict(d["optimizer_state_dict"])
+        early_model.test_accuracy = d["test_accuracy"]
+        early_model.nb_epochs = d["nb_epochs"]
+
+        log_string(f"successfully loaded {filename}")
+
+        print(f"-- {model.id} -------------------------------")
+        for ep, p in zip(early_model.parameters(), model.parameters()):
+            print(f"mean {model.id} {ep.mean()} {p.mean()} std {ep.std()} {p.std()}")
+
+    exit(0)
+
+######################################################################
+
 for n_epoch in range(current_epoch, args.nb_epochs):
     start_time = time.perf_counter()
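
For reference, a minimal sketch of how the new option is meant to be driven (the invocation below is an assumption; only the value "params" is handled by this commit):

import argparse

# Stripped-down parser with only the options touched by this commit,
# assuming the same defaults as main.py after the change.
parser = argparse.ArgumentParser()
parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
parser.add_argument("--proba_hints", type=float, default=0.25)
parser.add_argument("--test", type=str, default=None)

# Equivalent of running: main.py --test params
args = parser.parse_args(["--test", "params"])
print(args.test)  # "params" -> triggers the parameter-comparison path, which exits before training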
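
The "params" test expects one checkpoint per model, named ae_{model.id:03d}_naive.pth in args.result_dir, holding a dict with the keys it reads (state_dict, optimizer_state_dict, test_accuracy, nb_epochs). A minimal sketch of the matching save side, inferred from that load code and not part of this commit:

import os
import torch

def save_naive_checkpoint(model, result_dir):
    # Checkpoint layout inferred from the keys read by the "params" test;
    # model.optimizer, model.test_accuracy and model.nb_epochs are the
    # attributes the loading code restores.
    d = {
        "state_dict": model.state_dict(),
        "optimizer_state_dict": model.optimizer.state_dict(),
        "test_accuracy": model.test_accuracy,
        "nb_epochs": model.nb_epochs,
    }
    torch.save(d, os.path.join(result_dir, f"ae_{model.id:03d}_naive.pth"))

With such checkpoints in place, running main.py with --test params loads each file into a fresh model, prints per-tensor mean/std for the early versus current parameters, and exits before the training loop starts.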