+ def standard_validity(logproba):
+     # logproba is presumably (nb_quizzes, nb_models), each entry the
+     # log-probability a model assigns to the correct answer; a quiz is
+     # valid when at least one model gives it probability below 1/2.
+     sorted_logproba = logproba.sort(dim=-1).values
+     return sorted_logproba[:, 0] < math.log(0.5)
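+ # A minimal sketch of the criterion on a toy tensor, under the layout
+ # assumed above (rows = quizzes, columns = models):
+ #
+ #   >>> lp = torch.log(torch.tensor([[0.9, 0.3], [0.8, 0.7]]))
+ #   >>> standard_validity(lp)
+ #   tensor([ True, False])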
+
+
+######################################################################
+
+for n_epoch in range(args.nb_epochs):
+ log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+ cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+ log_string(f"current_test_accuracies {cta}")
+
+     ##################################################
+     # Select, improve, and evaluate the worst models
+
+     ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+
+     weakest_models = ranked_models[: len(gpus)]
+
+     threads = []
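+     # One weakest model is trained per available GPU, in parallel.
+     # Plain threads presumably give real overlap here: each one_epoch
+     # call is assumed to drive its own CUDA device, and kernels release
+     # the GIL while they run.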
+
+     for gpu, model in zip(gpus, weakest_models):
+         log_string(f"training model {model.id}")
+
+         t = threading.Thread(
+             target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+         )
+
+         threads.append(t)
+
+         t.start()
+
+     for t in threads:
+         t.join()
+
+     # Save the models to disk
+
+     for model in weakest_models:
+         filename = f"gpt_{model.id:03d}.pth"
+         torch.save(
+             (model.state_dict(), model.main_test_accuracy),
+             os.path.join(args.result_dir, filename),
+         )
+         log_string(f"wrote {filename}")
+
+     # Renew the training samples
+
+     for model in weakest_models:
+         quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
+
+     ##################################################
+     # If all the models are good enough, generate new quizzes and
+     # re-compute the test errors
+
+     if min(m.main_test_accuracy for m in models) >= args.accuracy_to_make_c_quizzes:
+         create_c_quizzes(