+
+######################################################################
+
+for n_epoch in range(args.nb_epochs):
+ log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+ cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+ log_string(f"current_test_accuracies {cta}")
+
+ ##################################################
+ # Select, improve, and eval the worst model
+
+ ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+
+ weakest_models = ranked_models[: args.nb_gpus]
+
+ threads = []
+
+ for gpu_id, model in enumerate(weakest_models):
+ log_string(f"training model {model.id}")
+
+ t = threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, f"cuda:{gpu_id}")
+ )
+
+ threads.append(t)
+
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ ##################################################
+ # Replace a fraction of the w_quizzes with fresh ones