From da145649c25f9e3f63aa132ef79b044ce72fd460 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 18:40:29 +0200 Subject: [PATCH] Update. --- main.py | 131 +++++++++++++++++++------------------------------------- 1 file changed, 45 insertions(+), 86 deletions(-) diff --git a/main.py b/main.py index c4ecc49..71adf30 100755 --- a/main.py +++ b/main.py @@ -42,8 +42,6 @@ parser.add_argument("--resume", action="store_true", default=False) parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) -parser.add_argument("--log_command", type=str, default=None) - # ---------------------------------- parser.add_argument("--nb_epochs", type=int, default=10000) @@ -58,23 +56,17 @@ parser.add_argument("--nb_train_samples", type=int, default=50000) parser.add_argument("--nb_test_samples", type=int, default=1000) -parser.add_argument("--nb_train_alien_samples", type=int, default=0) - -parser.add_argument("--nb_test_alien_samples", type=int, default=0) - -parser.add_argument("--nb_c_quizzes", type=int, default=2500) +parser.add_argument("--nb_c_quizzes", type=int, default=10000) parser.add_argument("--c_quiz_multiplier", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=5e-4) -parser.add_argument("--reboot", action="store_true", default=False) - parser.add_argument("--nb_have_to_be_correct", type=int, default=3) parser.add_argument("--nb_have_to_be_wrong", type=int, default=1) -parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5) +parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=10) # ---------------------------------- @@ -94,10 +86,6 @@ parser.add_argument("--dropout", type=float, default=0.5) # ---------------------------------- -parser.add_argument("--deterministic_synthesis", action="store_true", default=False) - -parser.add_argument("--problem", type=str, default="grids") - parser.add_argument("--nb_threads", type=int, default=1) parser.add_argument("--gpus", type=str, default="all") @@ -110,20 +98,12 @@ parser.add_argument("--diffusion_nb_iterations", type=int, default=25) parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05) -parser.add_argument("--min_succeed_to_validate", type=int, default=2) - parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) parser.add_argument("--proba_prompt_noise", type=float, default=0.05) parser.add_argument("--proba_hint", type=float, default=0.01) -# parser.add_argument("--nb_hints", type=int, default=25) - -parser.add_argument("--nb_runs", type=int, default=1) - -parser.add_argument("--test", type=str, default=None) - parser.add_argument("--quizzes", type=str, default=None) ###################################################################### @@ -141,18 +121,6 @@ parser.add_argument( ###################################################################### -parser.add_argument("--sky_height", type=int, default=6) - -parser.add_argument("--sky_width", type=int, default=8) - -parser.add_argument("--sky_nb_birds", type=int, default=3) - -parser.add_argument("--sky_nb_iterations", type=int, default=2) - -parser.add_argument("--sky_speed", type=int, default=3) - -###################################################################### - args = parser.parse_args() if args.result_dir is None: @@ -358,7 +326,7 @@ def optimizer_to(optim, device): # values from the target to the input -def add_hints(imt_set): +def add_hints_imt(imt_set): input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] # h = torch.rand(masks.size(), device=masks.device) - masks # t = h.sort(dim=1).values[:, args.nb_hints, None] @@ -375,7 +343,7 @@ def add_hints(imt_set): # args.proba_prompt_noise -def add_noise(imt_set): +def add_noise_imt(imt_set): input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] noise = quiz_machine.pure_noise(input.size(0), input.device) change = (1 - masks) * ( @@ -443,8 +411,8 @@ def predict_full(model, input, with_perturbations=False, local_device=main_devic imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) if with_perturbations: - imt_set = add_hints(imt_set) - imt_set = add_noise(imt_set) + imt_set = add_hints_imt(imt_set) + imt_set = add_noise_imt(imt_set) result = ae_predict(model, imt_set, local_device=local_device, desc=None) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) @@ -542,11 +510,17 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): ) q_p, q_g = quizzes.to(local_device).chunk(2) + + # Half of the samples train the prediction, and we inject noise in + # all, and hints in half b_p = batch_for_prediction_imt(q_p) i = torch.rand(b_p.size(0)) < 0.5 - b_p = add_noise(b_p) - b_p[i] = add_hints(b_p[i]) + b_p = add_noise_imt(b_p) + b_p[i] = add_hints_imt(b_p[i]) + + # The other half are denoising examples for the generation b_g = batch_for_generation_imt(q_g) + imt_set = torch.cat([b_p, b_g]) imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)] @@ -642,7 +616,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): model.test_accuracy = nb_correct / nb_total log_string( - f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)" + f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)" ) # Save some images of the ex nihilo generation of the four grids @@ -782,9 +756,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): ###################################################################### -def save_quiz_image( - models, c_quizzes, filename, solvable_only=False, local_device=main_device -): +def save_quiz_image(models, c_quizzes, filename, local_device=main_device): c_quizzes = c_quizzes.to(local_device) to_keep, nb_correct, nb_wrong = evaluate_quizzes( @@ -794,11 +766,6 @@ def save_quiz_image( local_device=local_device, ) - if solvable_only: - c_quizzes = c_quizzes[to_keep] - nb_correct = nb_correct[to_keep] - nb_wrong = nb_wrong[to_keep] - comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)] quiz_machine.problem.save_quizzes_as_image( @@ -821,29 +788,31 @@ if args.resume: for model in models: filename = f"ae_{model.id:03d}.pth" - try: - d = torch.load(os.path.join(args.result_dir, filename), map_location="cpu") - model.load_state_dict(d["state_dict"]) - model.optimizer.load_state_dict(d["optimizer_state_dict"]) - model.test_accuracy = d["test_accuracy"] - # model.gen_test_accuracy = d["gen_test_accuracy"] - # model.gen_state_dict = d["gen_state_dict"] - # model.train_c_quiz_bags = d["train_c_quiz_bags"] - # model.test_c_quiz_bags = d["test_c_quiz_bags"] - log_string(f"successfully loaded {filename}") - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass - - try: - filename = "state.pth" - state = torch.load(os.path.join(args.result_dir, filename)) + d = torch.load( + os.path.join(args.result_dir, filename), + map_location="cpu", + weights_only=False, + ) + model.load_state_dict(d["state_dict"]) + model.optimizer.load_state_dict(d["optimizer_state_dict"]) + model.test_accuracy = d["test_accuracy"] + # model.gen_test_accuracy = d["gen_test_accuracy"] + # model.gen_state_dict = d["gen_state_dict"] + # model.train_c_quiz_bags = d["train_c_quiz_bags"] + # model.test_c_quiz_bags = d["test_c_quiz_bags"] log_string(f"successfully loaded {filename}") - current_epoch = state["current_epoch"] - c_quizzes = state["c_quizzes"] - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass + + filename = "state.pth" + state = torch.load( + os.path.join(args.result_dir, filename), + map_location="cpu", + weights_only=False, + ) + + log_string(f"successfully loaded {filename}") + + current_epoch = state["current_epoch"] + c_quizzes = state["c_quizzes"] ###################################################################### @@ -918,9 +887,8 @@ def multithread_execution(fun, arguments): for t in threads: t.join() - if records[0] is None: + if records[0] == (None,): return - else: return [ torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0])) @@ -944,7 +912,8 @@ def save_models(models, suffix=""): }, os.path.join(args.result_dir, filename), ) - log_string(f"wrote {filename}") + + log_string(f"wrote ae_*{prefix}.pth") ###################################################################### @@ -983,20 +952,10 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) save_quiz_image( - models, - new_c_quizzes[:256], - f"culture_c_quiz_{n_epoch:04d}.png", - solvable_only=False, - ) - - save_quiz_image( - models, - new_c_quizzes[:256], - f"culture_c_quiz_{n_epoch:04d}_solvable.png", - solvable_only=True, + models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png" ) - log_string(f"generated_c_quizzes {new_c_quizzes.size()=}") + log_string(f"generated_c_quizzes {new_c_quizzes.size()}") c_quizzes = ( new_c_quizzes -- 2.39.5