parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
-parser.add_argument("--log_command", type=str, default=None)
-
# ----------------------------------
parser.add_argument("--nb_epochs", type=int, default=10000)
parser.add_argument("--nb_test_samples", type=int, default=1000)
-parser.add_argument("--nb_train_alien_samples", type=int, default=0)
-
-parser.add_argument("--nb_test_alien_samples", type=int, default=0)
-
-parser.add_argument("--nb_c_quizzes", type=int, default=2500)
+parser.add_argument("--nb_c_quizzes", type=int, default=10000)
parser.add_argument("--c_quiz_multiplier", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-4)
-parser.add_argument("--reboot", action="store_true", default=False)
-
parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
-parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=10)
# ----------------------------------
# ----------------------------------
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-
-parser.add_argument("--problem", type=str, default="grids")
-
parser.add_argument("--nb_threads", type=int, default=1)
parser.add_argument("--gpus", type=str, default="all")
parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
-parser.add_argument("--min_succeed_to_validate", type=int, default=2)
-
parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
parser.add_argument("--proba_hint", type=float, default=0.01)
-# parser.add_argument("--nb_hints", type=int, default=25)
-
-parser.add_argument("--nb_runs", type=int, default=1)
-
-parser.add_argument("--test", type=str, default=None)
-
parser.add_argument("--quizzes", type=str, default=None)
######################################################################
######################################################################
-parser.add_argument("--sky_height", type=int, default=6)
-
-parser.add_argument("--sky_width", type=int, default=8)
-
-parser.add_argument("--sky_nb_birds", type=int, default=3)
-
-parser.add_argument("--sky_nb_iterations", type=int, default=2)
-
-parser.add_argument("--sky_speed", type=int, default=3)
-
-######################################################################
-
args = parser.parse_args()
if args.result_dir is None:
# values from the target to the input
-def add_hints(imt_set):
+def add_hints_imt(imt_set):
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
# h = torch.rand(masks.size(), device=masks.device) - masks
# t = h.sort(dim=1).values[:, args.nb_hints, None]
# args.proba_prompt_noise
-def add_noise(imt_set):
+def add_noise_imt(imt_set):
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
noise = quiz_machine.pure_noise(input.size(0), input.device)
change = (1 - masks) * (
imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
if with_perturbations:
- imt_set = add_hints(imt_set)
- imt_set = add_noise(imt_set)
+ imt_set = add_hints_imt(imt_set)
+ imt_set = add_noise_imt(imt_set)
result = ae_predict(model, imt_set, local_device=local_device, desc=None)
result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
)
q_p, q_g = quizzes.to(local_device).chunk(2)
+
+ # The first half of the samples trains the prediction: noise is injected
+ # in all of them, and hints in a random subset (each with probability 0.5)
b_p = batch_for_prediction_imt(q_p)
i = torch.rand(b_p.size(0)) < 0.5
- b_p = add_noise(b_p)
- b_p[i] = add_hints(b_p[i])
+ b_p = add_noise_imt(b_p)
+ b_p[i] = add_hints_imt(b_p[i])
+
+ # The other half are denoising examples for the generation
b_g = batch_for_generation_imt(q_g)
+
imt_set = torch.cat([b_p, b_g])
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
model.test_accuracy = nb_correct / nb_total
log_string(
- f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)"
+ f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
)
# Save some images of the ex nihilo generation of the four grids
######################################################################
-def save_quiz_image(
- models, c_quizzes, filename, solvable_only=False, local_device=main_device
-):
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
c_quizzes = c_quizzes.to(local_device)
to_keep, nb_correct, nb_wrong = evaluate_quizzes(
local_device=local_device,
)
- if solvable_only:
- c_quizzes = c_quizzes[to_keep]
- nb_correct = nb_correct[to_keep]
- nb_wrong = nb_wrong[to_keep]
-
comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
quiz_machine.problem.save_quizzes_as_image(
for model in models:
filename = f"ae_{model.id:03d}.pth"
- try:
- d = torch.load(os.path.join(args.result_dir, filename), map_location="cpu")
- model.load_state_dict(d["state_dict"])
- model.optimizer.load_state_dict(d["optimizer_state_dict"])
- model.test_accuracy = d["test_accuracy"]
- # model.gen_test_accuracy = d["gen_test_accuracy"]
- # model.gen_state_dict = d["gen_state_dict"]
- # model.train_c_quiz_bags = d["train_c_quiz_bags"]
- # model.test_c_quiz_bags = d["test_c_quiz_bags"]
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
- try:
- filename = "state.pth"
- state = torch.load(os.path.join(args.result_dir, filename))
+ d = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.test_accuracy = d["test_accuracy"]
+ # model.gen_test_accuracy = d["gen_test_accuracy"]
+ # model.gen_state_dict = d["gen_state_dict"]
+ # model.train_c_quiz_bags = d["train_c_quiz_bags"]
+ # model.test_c_quiz_bags = d["test_c_quiz_bags"]
log_string(f"successfully loaded {filename}")
- current_epoch = state["current_epoch"]
- c_quizzes = state["c_quizzes"]
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
+
+ filename = "state.pth"
+ state = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+
+ log_string(f"successfully loaded {filename}")
+
+ current_epoch = state["current_epoch"]
+ c_quizzes = state["c_quizzes"]
######################################################################
for t in threads:
t.join()
- if records[0] is None:
+ if records[0] == (None,):
return
-
else:
return [
torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
},
os.path.join(args.result_dir, filename),
)
- log_string(f"wrote {filename}")
+
+ log_string(f"wrote ae_*{prefix}.pth")
######################################################################
)
save_quiz_image(
- models,
- new_c_quizzes[:256],
- f"culture_c_quiz_{n_epoch:04d}.png",
- solvable_only=False,
- )
-
- save_quiz_image(
- models,
- new_c_quizzes[:256],
- f"culture_c_quiz_{n_epoch:04d}_solvable.png",
- solvable_only=True,
+ models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
)
- log_string(f"generated_c_quizzes {new_c_quizzes.size()=}")
+ log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
c_quizzes = (
new_c_quizzes