def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
    """Generate, validate, and store new "culture" quizzes.

    Repeatedly picks a model at random to generate candidate quizzes,
    discards the trivial ones, and keeps only the candidates on which
    the whole ensemble of models is decisive (every model either clearly
    succeeds or clearly fails), with at least one failure and at most
    ``args.max_fail_to_validate`` failures. Accumulation continues until
    ``nb_for_train + nb_for_test`` quizzes have been validated; the first
    ``nb_for_train`` are stored as training quizzes, the next
    ``nb_for_test`` as test quizzes, and up to 128 randomly chosen
    validated quizzes are rendered to a PNG.

    Args:
        models: sequence of models; each must expose an integer ``.id``
            usable as an index into the per-model counters.
        quiz_machine: provides ``generate_c_quizzes``, ``models_logprobas``,
            ``store_c_quizzes``, and ``problem`` (with ``trivial`` and
            ``save_quizzes_as_image``).
        nb_for_train: number of validated quizzes to store for training.
        nb_for_test: number of validated quizzes to store for testing.

    Relies on module-level globals: ``args``, ``c_quizzes_procedure``,
    ``log_string``, ``n_epoch``.
    """
    nb_to_validate = nb_for_train + nb_for_test

    # NOTE(review): `max` means every iteration generates at least the
    # full target count; `min` (generate in physical-batch-size chunks)
    # looks like the intent — confirm before changing.
    nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)

    recorded_validated = []

    start_time = time.perf_counter()

    nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)

    while nb_validated_per_model.sum() < nb_to_validate:
        # We could use the model that has generated the fewest quizzes
        # to balance the number of quizzes per model overall:
        #
        # model_for_generation = sorted(
        #     models, key=lambda m: nb_validated_per_model[m.id]
        # )[0]
        #
        # but we currently pick one uniformly at random.
        model_for_generation = models[torch.randint(len(models), (1,)).item()]

        # We generate quizzes with a procedure that injects some
        # structured noise.
        #
        # BUG FIX: the original passed the undefined name `model` here,
        # which raised a NameError; `model_for_generation` is the model
        # selected just above.
        c_quizzes = quiz_machine.generate_c_quizzes(
            nb_to_generate_per_iteration,
            model_for_generation=model_for_generation,
            procedure=c_quizzes_procedure,
        )

        # We discard the trivial ones, according to a criterion
        # specific to the world quizzes (e.g. B=f(B)).
        to_keep = quiz_machine.problem.trivial(c_quizzes) == False
        c_quizzes = c_quizzes[to_keep]

        # Log-probability of each quiz under each model, summed over the
        # two orderings of the quadruple. Shape: nb_quizzes x nb_models.
        seq_logproba = quiz_machine.models_logprobas(
            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
        ) + quiz_machine.models_logprobas(
            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
        )

        probas = seq_logproba.exp()

        # Per-quiz counts of models that clearly understand / clearly
        # do not understand the quiz.
        nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
        nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)

        # Keep a quiz only when every model is decisive, at least one
        # fails, and not too many fail.
        to_keep = (
            (nb_succeed + nb_fail == probas.size(1))
            & (nb_fail >= 1)
            & (nb_fail <= args.max_fail_to_validate)
        )

        c_quizzes = c_quizzes[to_keep]

        nb_validated = c_quizzes.size(0)
        if nb_validated > 0:
            nb_validated_per_model[model_for_generation.id] += nb_validated
            recorded_validated.append(c_quizzes)

        total_nb_validated = nb_validated_per_model.sum().item()

        duration = time.perf_counter() - start_time

        # Estimate the wall-clock completion time from the current
        # validation rate.
        if total_nb_validated > 0:
            if total_nb_validated < nb_to_validate:
                d = (
                    (nb_to_validate - total_nb_validated)
                    * duration
                    / total_nb_validated
                )
                e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
                    "%a %H:%M"
                )
            else:
                e = "now!"
        else:
            e = "???"

        log_string(
            f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
        )

    validated_quizzes = torch.cat(recorded_validated, dim=0)

    ######################################################################
    # Store the new c_quizzes which have been validated: the first
    # nb_for_train for training, the next nb_for_test for testing.

    quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
    quiz_machine.store_c_quizzes(
        validated_quizzes[nb_for_train:nb_to_validate], for_train=False
    )

    ######################################################################
    # Save up to 128 randomly chosen validated quizzes as an image,
    # annotated with each model's probability of understanding them.

    vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]

    if vq.size(0) > 0:
        seq_logproba = quiz_machine.models_logprobas(
            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
        ) + quiz_machine.models_logprobas(
            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
        )

        comments = [
            "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
            for l in seq_logproba
        ]

        filename = f"culture_c_quiz_{n_epoch:04d}.png"
        quiz_machine.problem.save_quizzes_as_image(
            args.result_dir, filename, vq, comments=comments
        )