From ece66035329cf86a322560779672d84652dd2a12 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 21 Jul 2024 00:45:18 +0200 Subject: [PATCH] Update. --- main.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 7588a50..653f5f5 100755 --- a/main.py +++ b/main.py @@ -472,11 +472,19 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 temperature_cold=args.temperature_cold, ) - recorded_too_simple.append( - keep_good_quizzes(models, c_quizzes, required_nb_failures=0) + c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] + + nc = quiz_machine.solution_nb_correct(models, c_quizzes) + + count_nc = tuple( + n.item() for n in F.one_hot(nc, num_classes=len(models) + 1).sum(dim=0) ) - c_quizzes = keep_good_quizzes(models, c_quizzes) + log_string(f"nb_correct {count_nc}") + + recorded_too_simple.append(c_quizzes[nc == len(models)]) + + c_quizzes = c_quizzes[nc == len(models) - 1] nb_validated[model_for_generation.id] += c_quizzes.size(0) total_nb_validated = nb_validated.sum().item() @@ -517,7 +525,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 ###################################################################### # save images - vq = validated_quizzes[:128] + vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]] if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}" @@ -525,7 +533,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 args.result_dir, prefix, vq, show_part_to_predict=False ) - vq = too_simple_quizzes[:128] + vq = too_simple_quizzes if vq.size(0) > 0: prefix = f"culture_c_quiz_{n_epoch:04d}_too_simple" @@ -642,6 +650,11 @@ if args.dirty_debug: ###################################################################### for n_epoch in range(current_epoch, args.nb_epochs): + state = {"current_epoch": n_epoch} + filename = "state.pth" + torch.save(state, os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + log_string(f"--- epoch {n_epoch} ----------------------------------------") cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) @@ -700,11 +713,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) log_string(f"wrote {filename}") - state = {"current_epoch": n_epoch} - filename = "state.pth" - torch.save(state, os.path.join(args.result_dir, filename)) - log_string(f"wrote {filename}") - # Renew the training samples for model in weakest_models: -- 2.20.1