From e6354ce8c44df03b59d06dc4702fd23ee7086223 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 21 Sep 2024 05:12:48 +0200 Subject: [PATCH] Update. --- grids.py | 3 +- main.py | 109 ++++++++++++++++++++++++++++++------------------------- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/grids.py b/grids.py index e5890ca..78d9297 100755 --- a/grids.py +++ b/grids.py @@ -287,7 +287,8 @@ class Grids(problem.Problem): ###################################################################### def vocabulary_size(self): - warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) + # warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) + # return self.nb_colors+4 return self.nb_colors def grid2img(self, x, scale=15, grids=True): diff --git a/main.py b/main.py index 961ae81..21666d1 100755 --- a/main.py +++ b/main.py @@ -47,7 +47,7 @@ parser.add_argument("--eval_batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=50000) -parser.add_argument("--nb_test_samples", type=int, default=1000) +parser.add_argument("--nb_test_samples", type=int, default=2500) parser.add_argument("--nb_c_quizzes", type=int, default=5000) @@ -252,6 +252,25 @@ assert args.nb_test_samples % args.batch_size == 0 ###################################################################### +def optimizer_to(optim, device): + """Move the optimizer optim to the device""" + for param in optim.state.values(): + # Not sure there are any global tensors in the state dict + if isinstance(param, torch.Tensor): + param.data = param.data.to(device) + if param._grad is not None: + param._grad.data = param._grad.data.to(device) + elif isinstance(param, dict): + for subparam in param.values(): + if isinstance(subparam, torch.Tensor): + subparam.data = subparam.data.to(device) + if subparam._grad is not None: + subparam._grad.data = subparam._grad.data.to(device) + + +###################################################################### + + def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): if c_quizzes is None: quizzes = problem.generate_w_quizzes(nb_samples) @@ -290,25 +309,6 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): ###################################################################### -def optimizer_to(optim, device): - """Move the optimizer optim to the device""" - for param in optim.state.values(): - # Not sure there are any global tensors in the state dict - if isinstance(param, torch.Tensor): - param.data = param.data.to(device) - if param._grad is not None: - param._grad.data = param._grad.data.to(device) - elif isinstance(param, dict): - for subparam in param.values(): - if isinstance(subparam, torch.Tensor): - subparam.data = subparam.data.to(device) - if subparam._grad is not None: - subparam._grad.data = subparam._grad.data.to(device) - - -###################################################################### - - def add_hints_imt(imt_set): """Set every component of the mask to zero with probability args.proba_hint, and for each component set to zero, copy the @@ -589,10 +589,12 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de ###################################################################### -def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): - one_epoch(model, n_epoch, c_quizzes, train=True, local_device=local_device) +def one_complete_epoch( + model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device +): + one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device) - one_epoch(model, n_epoch, c_quizzes, train=False, local_device=local_device) + one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device) # Compute the test accuracy @@ -627,7 +629,7 @@ def max_nb_mistakes_on_one_grid(quizzes, prediction): ) -def evaluate_quizzes(quizzes, models, local_device): +def evaluate_quizzes(quizzes, models, with_hints, local_device): nb_correct, nb_wrong = 0, 0 for model in models: @@ -636,20 +638,16 @@ def evaluate_quizzes(quizzes, models, local_device): model=model, input=quizzes, with_noise=False, - with_hints=True, + with_hints=with_hints, local_device=local_device, ) nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted) nb_correct += (nb_mistakes == 0).long() nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long() - to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( - nb_wrong >= args.nb_have_to_be_wrong - ) - # print("\n\n", nb_correct, nb_wrong) - return to_keep, nb_correct, nb_wrong + return nb_correct, nb_wrong ###################################################################### @@ -686,12 +684,17 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): # Select the ones that are solved properly by some models and # not understood by others - to_keep, nb_correct, nb_wrong = evaluate_quizzes( + nb_correct, nb_wrong = evaluate_quizzes( quizzes=c_quizzes, models=models, + with_hints=True, local_device=local_device, ) + to_keep = (nb_correct >= args.nb_have_to_be_correct) & ( + nb_wrong >= args.nb_have_to_be_wrong + ) + nb_validated += to_keep.long().sum().item() record.append(c_quizzes[to_keep]) @@ -743,8 +746,8 @@ def multithread_execution(fun, arguments): for args in arguments: # To get a different sequence between threads - # log_string(f"dummy_rand {torch.rand(1)}") - torch.rand(1) + log_string(f"dummy_rand {torch.rand(1)}") + # torch.rand(1) t = threading.Thread(target=threadable_fun, daemon=True, args=args) threads.append(t) t.start() @@ -787,9 +790,10 @@ def save_models(models, suffix=""): def save_quiz_image(models, c_quizzes, filename, local_device=main_device): c_quizzes = c_quizzes.to(local_device) - to_keep, nb_correct, nb_wrong = evaluate_quizzes( + nb_correct, nb_wrong = evaluate_quizzes( quizzes=c_quizzes, models=models, + with_hints=False, local_device=local_device, ) @@ -873,10 +877,6 @@ if args.resume: model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] - # model.gen_test_accuracy = d["gen_test_accuracy"] - # model.gen_state_dict = d["gen_state_dict"] - # model.train_c_quiz_bags = d["train_c_quiz_bags"] - # model.test_c_quiz_bags = d["test_c_quiz_bags"] log_string(f"successfully loaded {filename}") filename = "state.pth" @@ -889,7 +889,8 @@ if args.resume: log_string(f"successfully loaded {filename}") current_epoch = state["current_epoch"] - c_quizzes = state["c_quizzes"] + train_c_quizzes = state["train_c_quizzes"] + test_c_quizzes = state["test_c_quizzes"] ###################################################################### @@ -899,7 +900,7 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -main_c_quizzes = None +train_c_quizzes, test_c_quizzes = None, None ###################################################################### @@ -908,7 +909,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): state = { "current_epoch": n_epoch, - "main_c_quizzes": main_c_quizzes, + "train_c_quizzes": train_c_quizzes, + "test_c_quizzes": test_c_quizzes, } filename = "state.pth" @@ -925,7 +927,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): lowest_test_accuracy = min([float(m.test_accuracy) for m in models]) if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes: - if main_c_quizzes is None: + if train_c_quizzes is None: save_models(models, "naive") nb_gpus = len(gpus) @@ -942,20 +944,29 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"generated_c_quizzes {new_c_quizzes.size()}") - main_c_quizzes = ( + train_c_quizzes = ( new_c_quizzes - if main_c_quizzes is None - else torch.cat([main_c_quizzes, new_c_quizzes]) + if train_c_quizzes is None + else torch.cat([train_c_quizzes, new_c_quizzes]) ) - main_c_quizzes = main_c_quizzes[-args.nb_train_samples :] + train_c_quizzes = train_c_quizzes[-args.nb_train_samples :] + + nb_correct, _ = evaluate_quizzes( + quizzes=train_c_quizzes, + models=models, + with_hints=False, + local_device=local_device, + ) + + test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct] for model in models: model.test_accuracy = 0 - if main_c_quizzes is None: + if train_c_quizzes is None: log_string("no_c_quiz") else: - log_string(f"nb_c_quizzes {main_c_quizzes.size(0)}") + log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}") # -------------------------------------------------------------------- @@ -969,7 +980,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): multithread_execution( one_complete_epoch, [ - (model, n_epoch, main_c_quizzes, gpu) + (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus) ], ) -- 2.39.5