+ filename = "c_quizzes.pth"
+ quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
+
+ ##################################################
+ # Select, improve, and eval the worst model
+
+ ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+
+ weakest_models = ranked_models[: len(gpus)]
+
+ threads = []
+
+ for gpu, model in zip(gpus, weakest_models):
+ log_string(f"training model {model.id}")
+
+ t = threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+ )
+
+ threads.append(t)
+
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ # Save the models to disk
+
+ for model in weakest_models:
+ filename = f"gpt_{model.id:03d}.pth"
+ torch.save(
+ (model.state_dict(), model.main_test_accuracy),
+ os.path.join(args.result_dir, filename),
+ )
+ log_string(f"wrote {filename}")
+
+ # Renew the training samples
+
+ for model in weakest_models:
+ quiz_machine.renew_w_quizzes(model, args.nb_train_samples)