From 5a660807ce4a1d95c2a0123796685efea1f8dd2b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 1 Sep 2024 18:29:42 +0200 Subject: [PATCH] Update. --- main.py | 63 ++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index e533802..b87518e 100755 --- a/main.py +++ b/main.py @@ -51,7 +51,7 @@ parser.add_argument("--batch_size", type=int, default=25) parser.add_argument("--physical_batch_size", type=int, default=None) -parser.add_argument("--inference_batch_size", type=int, default=25) +parser.add_argument("--inference_batch_size", type=int, default=50) parser.add_argument("--nb_train_samples", type=int, default=40000) @@ -61,7 +61,7 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) -parser.add_argument("--c_quiz_multiplier", type=int, default=1) +parser.add_argument("--c_quiz_multiplier", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=5e-4) @@ -973,7 +973,10 @@ def ae_batches( c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")] full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input( - nb, c_quiz_bags, data_structures=data_structures + nb, + c_quiz_bags, + data_structures=data_structures, + c_quiz_multiplier=args.c_quiz_multiplier, ) src = zip( @@ -1052,14 +1055,14 @@ def ae_generate(model, input, mask_generate, nb_iterations_max=50): update = (1 - mask_to_change) * input + mask_to_change * final if update.equal(input): - log_string(f"exit after {it+1} iterations") + # log_string(f"exit after {it+1} iterations") break else: changed = changed & (update != input).max(dim=1).values input[changed] = update[changed] - if it == nb_iterations_max: - log_string(f"remains {changed.long().sum()}") + # if it == nb_iterations_max: + # log_string(f"remains {changed.long().sum()}") return input @@ -1348,7 +1351,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): quad_order = ("A", "f_A", "B", "f_B") template = quiz_machine.problem.create_empty_quizzes( - nb=args.batch_size, quad_order=quad_order + nb=args.inference_batch_size, quad_order=quad_order ).to(local_device) mask_generate = quiz_machine.make_quiz_mask( @@ -1357,15 +1360,16 @@ def generate_ae_c_quizzes(models, local_device=main_device): duration_max = 4 * 3600 - # wanted_nb = 240 - # nb_to_save = 240 - - wanted_nb = args.nb_train_samples // 4 + wanted_nb = 128 nb_to_save = 128 + # wanted_nb = args.nb_train_samples // args.c_quiz_multiplier + # nb_to_save = 256 + with torch.autograd.no_grad(): records = [[] for _ in criteria] + last_log = -1 start_time = time.perf_counter() while ( @@ -1374,8 +1378,6 @@ def generate_ae_c_quizzes(models, local_device=main_device): ): model = models[torch.randint(len(models), (1,)).item()] result = ae_generate(model, template, mask_generate) - bl = [bag_len(bag) for bag in records] - log_string(f"bag_len {bl} model {model.id}") to_keep = quiz_machine.problem.trivial(result) == False result = result[to_keep] @@ -1394,13 +1396,34 @@ def generate_ae_c_quizzes(models, local_device=main_device): if q.size(0) > 0: r.append(q) + duration = time.perf_counter() - start_time + nb_generated = min([bag_len(bag) for bag in records]) + + if last_log < 0 or duration > last_log + 60: + last_log = duration + if nb_generated > 0: + if nb_generated < wanted_nb: + d = (wanted_nb - nb_generated) * duration / nb_generated + e = ( + datetime.datetime.now() + datetime.timedelta(seconds=d) + ).strftime("%a %H:%M") + else: + e = "now!" + else: + e = "???" + + bl = [bag_len(bag) for bag in records] + log_string( + f"bag_len {bl} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)" + ) + duration = time.perf_counter() - start_time log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h") for n, u in enumerate(records): quizzes = torch.cat(u, dim=0)[:nb_to_save] - filename = f"culture_c_{n_epoch:04d}_{n:02d}.png" + filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png" # result, predicted_parts, correct_parts = bag_to_tensors(record) @@ -1419,7 +1442,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): # correct_parts=correct_parts, comments=comments, delta=True, - nrow=12, + nrow=8, ) log_string(f"wrote {filename}") @@ -1475,6 +1498,9 @@ last_n_epoch_c_quizzes = 0 c_quizzes = None +time_c_quizzes = 0 +time_train = 0 + for n_epoch in range(current_epoch, args.nb_epochs): start_time = time.perf_counter() @@ -1501,10 +1527,13 @@ for n_epoch in range(current_epoch, args.nb_epochs): if ( n_epoch >= 200 and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes - and n_epoch >= last_n_epoch_c_quizzes + 10 + and time_train >= time_c_quizzes ): last_n_epoch_c_quizzes = n_epoch + start_time = time.perf_counter() c_quizzes = generate_ae_c_quizzes(models, local_device=main_device) + time_c_quizzes = time.perf_counter() - start_time + time_train = 0 if c_quizzes is None: log_string("no_c_quiz") @@ -1534,6 +1563,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): for t in threads: t.join() + time_train += time.perf_counter() - start_time + # -------------------------------------------------------------------- for model in models: -- 2.39.5