+ log_string(f"wrote {filename}")
+
+ ######################################################################
+
+ if science_w_quizzes is not None:
+ struct = ("A", "f_A", "B", "f_B")
+ mask = (0, 0, 0, 1)
+ result, correct = quiz_machine.predict(
+ model=model,
+ quizzes=science_w_quizzes.to(main_device),
+ struct=struct,
+ mask=mask,
+ )
+
+ predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
+ correct.size(0), -1
+ )
+ correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
+
+ nb_correct = (correct == 1).long().sum()
+ nb_total = (correct != 0).long().sum()
+
+ log_string(
+ f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
+ )
+
+ i = correct == 1
+ j = correct != 1
+
+ result = torch.cat([result[i], result[j]], dim=0)
+ correct = torch.cat([correct[i], correct[j]], dim=0)
+ correct_parts = predicted_parts * correct[:, None]
+
+ result = result[:128]
+ predicted_parts = predicted_parts[:128]
+ correct_parts = correct_parts[:128]
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
+ quizzes=result,
+ predicted_parts=predicted_parts,
+ correct_parts=correct_parts,
+ )
+
+
+######################################################################
+
+
def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
    """Generate c_quizzes, keep the validated ones, and store them.

    Batches of c_quizzes are generated until at least
    ``nb_for_train + nb_for_test`` of them have been validated. For
    each batch:

    * one of ``models`` is drawn uniformly at random and used as the
      generating model,
    * quizzes flagged by ``quiz_machine.problem.trivial`` are dropped,
    * a quiz is validated when every model's probability is either
      ``>= args.proba_understands`` (success) or
      ``<= args.proba_not_understands`` (failure), with at least one
      and at most ``args.max_fail_to_validate`` failures.

    The first ``nb_for_train`` validated quizzes are stored with
    ``for_train=True``, the following ones (up to ``nb_for_test``) with
    ``for_train=False``. Up to 128 validated quizzes, picked at random,
    are saved as an image in ``args.result_dir`` with the per-model
    probabilities as comments.

    Besides its arguments, this function reads the module-level names
    ``args``, ``c_quizzes_procedure``, ``n_epoch`` and ``log_string``.

    Args:
        models: sequence of models; ``model.id`` is used to index a
            counter of size ``len(models)``, so ids are assumed to be
            in ``range(len(models))`` -- TODO confirm with the caller.
        quiz_machine: object providing ``generate_c_quizzes``,
            ``models_logprobas``, ``store_c_quizzes`` and ``problem``.
        nb_for_train: number of validated quizzes stored for training.
        nb_for_test: number of validated quizzes stored for testing.

    Returns:
        None.
    """
    nb_to_validate = nb_for_train + nb_for_test

    if nb_to_validate <= 0:
        # Nothing is requested: the generation loop would not run and
        # torch.cat would fail on an empty list
        return

    nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)

    def both_structures_logprobas(quizzes):
        # Sum of the log-probabilities computed under the
        # ("A", "f_A", "B", "f_B") and ("f_A", "A", "f_B", "B")
        # structures. This is nb_quizzes x nb_models
        return quiz_machine.models_logprobas(
            models, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
        ) + quiz_machine.models_logprobas(
            models, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
        )

    recorded_validated = []
    nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
    start_time = time.perf_counter()

    while nb_validated_per_model.sum() < nb_to_validate:
        # The generating model is drawn uniformly at random
        model_for_generation = models[torch.randint(len(models), (1,)).item()]

        # We generate quizzes with a procedure that injects some
        # structured noise. The model passed here must be the one
        # drawn above, since it is the one credited in the counter
        # and in the log below

        c_quizzes = quiz_machine.generate_c_quizzes(
            nb_to_generate_per_iteration,
            model_for_generation=model_for_generation,
            procedure=c_quizzes_procedure,
        )

        # We discard the trivial ones, according to a criterion
        # specific to the world quizzes (e.g. B=f(B)). The comparison
        # with False is element-wise

        to_keep = quiz_machine.problem.trivial(c_quizzes) == False

        c_quizzes = c_quizzes[to_keep]

        probas = both_structures_logprobas(c_quizzes).exp()

        nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
        nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)

        # Every model must either succeed or fail (no in-between), and
        # the number of failing models must be in
        # [1, args.max_fail_to_validate]

        to_keep = (
            (nb_succeed + nb_fail == probas.size(1))
            & (nb_fail >= 1)
            & (nb_fail <= args.max_fail_to_validate)
        )

        c_quizzes = c_quizzes[to_keep]

        nb_validated = c_quizzes.size(0)

        if nb_validated > 0:
            nb_validated_per_model[model_for_generation.id] += nb_validated
            recorded_validated.append(c_quizzes)

        total_nb_validated = nb_validated_per_model.sum().item()

        duration = time.perf_counter() - start_time

        # Estimated time of completion, extrapolated from the average
        # validation rate so far

        if total_nb_validated == 0:
            eta = "???"
        elif total_nb_validated >= nb_to_validate:
            eta = "now!"
        else:
            remaining_seconds = (
                (nb_to_validate - total_nb_validated) * duration / total_nb_validated
            )
            eta = (
                datetime.datetime.now()
                + datetime.timedelta(seconds=remaining_seconds)
            ).strftime("%a %H:%M")

        log_string(
            f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {eta} -- {int((total_nb_validated * 3600)/duration)}/h)"
        )

    validated_quizzes = torch.cat(recorded_validated, dim=0)

    ######################################################################
    # store the new c_quizzes which have been validated

    v_train = validated_quizzes[:nb_for_train]
    quiz_machine.store_c_quizzes(v_train, for_train=True)

    v_test = validated_quizzes[nb_for_train:nb_to_validate]
    quiz_machine.store_c_quizzes(v_test, for_train=False)

    ######################################################################
    # save images

    vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]

    if vq.size(0) > 0:
        seq_logproba = both_structures_logprobas(vq)

        # One comment per quiz, listing the probability under each model
        comments = [
            "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
            for l in seq_logproba
        ]

        filename = f"culture_c_quiz_{n_epoch:04d}.png"
        quiz_machine.problem.save_quizzes_as_image(
            args.result_dir, filename, vq, comments=comments
        )
+
+
+######################################################################
+
# The generator is very similar to a "solving GPT", except that it
# deals with quizzes prefixed with one token per solving GPT, which
# indicates whether that model solves the quiz or not.
#
# There are three levels of solving: 0 if proba <=
# proba_not_understands, 2 if proba >= proba_understands, and 1
# otherwise.
+
+
+def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
+ generator.to(main_device)
+
+ struct = ("A", "f_A", "B", "f_B")
+
+ c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
+ ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1))
+
+ i = F.one_hot(
+ torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
+ num_classes=args.nb_gpts,