+ quiz_machine.reverse_random_half_in_place(q)
+
+ if q.size(0) > 0:
+ quiz_machine.save_quizzes(
+ args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q
+ )
+
+
+######################################################################
+
+models = []
+
+for k in range(args.nb_gpts):
+ model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ causal=True,
+ dropout=args.dropout,
+ ).to(device)
+
+ model.main_test_accuracy = 0.0
+ model.id = k
+
+ models.append(model)
+
+
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+######################################################################
+
+nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+
+log_string(
+ f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+)
+
+######################################################################
+
+if args.dirty_debug:
+ args.accuracy_to_make_c_quizzes = 0.0
+ args.nb_gpts = 2
+ nb_new_c_quizzes_for_train = 100
+ nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
+for n_epoch in range(args.nb_epochs):
+ log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+ cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+ log_string(f"current_test_accuracies {cta}")
+
+ ##################################################
+ # Select, improve, and eval the worst model
+
+ weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
+
+ log_string(
+ f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
+ )
+
+ one_epoch(weakest_model, quiz_machine)
+
+ log_string(
+ f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
+ )
+
+ run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
+
+ log_string(
+ f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
+ )
+
+ ##################################################
+ # Replace a fraction of the w_quizzes with fresh ones
+
+ quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+
+ ##################################################
+ # If all the models are good enough, generate new quizzes and
+ # re-compute the test errors
+
+ if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+ create_c_quizzes(
+ models,
+ quiz_machine,
+ nb_for_train=nb_new_c_quizzes_for_train,
+ nb_for_test=nb_new_c_quizzes_for_test,
+ )
+
+ for model in models:
+ run_tests(model, quiz_machine, deterministic_synthesis=False)