parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
-parser.add_argument("--proba_hints", type=float, default=0.1)
+parser.add_argument("--proba_hints", type=float, default=0.25)
+
+# ----------------------------------
+
+parser.add_argument("--test", type=str, default=None)
######################################################################
# model.test_accuracy = 0
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+if args.test == "params":
+ for model in models:
+ filename = f"ae_{model.id:03d}_naive.pth"
+
+ d = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+
+ early_model = new_model()
+ early_model.load_state_dict(d["state_dict"])
+ early_model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ early_model.test_accuracy = d["test_accuracy"]
+ early_model.nb_epochs = d["nb_epochs"]
+
+ log_string(f"successfully loaded {filename}")
+
+ print(f"-- {model.id} -------------------------------")
+ for ep, p in zip(early_model.parameters(), model.parameters()):
+ print(f"mean {model.id} {ep.mean()} {p.mean()} std {ep.std()} {p.std()}")
+
+ exit(0)
+
+######################################################################
+
for n_epoch in range(current_epoch, args.nb_epochs):
start_time = time.perf_counter()