- # concatenate and shuffle
- for n in recorded.keys():
- if len(recorded[n]) > 0:
- q = torch.cat(recorded[n], dim=0)
- q = q[torch.randperm(q.size(0), device=q.device)]
- recorded[n] = q
- else:
- del recorded[n]
-
- new_c_quizzes = recorded[nb_correct_to_validate][: nb_for_train + nb_for_test]
-
- quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
- quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
-
- for n in recorded.keys():
- s = "_validated" if n == nb_correct_to_validate else ""
- quizz_machine.problem.save_quizzes(
- recorded[n][:72],
- args.result_dir,
- f"culture_c_quiz_{n_epoch:04d}_N{n}{s}",
+ quiz_machine.reverse_random_half_in_place(new_c_quizzes)
+
+ quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+ quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+
+ # save a bunch of images to investigate what quizzes with a
+ # certain nb of correct predictions look like
+
+ for n in range(len(models) + 1):
+ s = (
+ "_validated"
+ if n >= args.min_to_validate and n <= args.max_to_validate
+ else ""