parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
+
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--lambda_H", type=float, default=0.0)
nb_train_samples, acc_train_loss = 0, 0.0
full_input, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples, model.train_c_quiz_bags
+ args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier
)
src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
######################################################################
if args.nb_new_c_quizzes_for_train is None:
- args.nb_new_c_quizzes_for_train = args.nb_train_samples // 1000
+ args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250
if args.nb_new_c_quizzes_for_test is None:
- args.nb_new_c_quizzes_for_test = args.nb_test_samples // 1000
+ args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250
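+# i.e. by default one new c_quiz for every 250 train / test samples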
log_string(
f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
log_string(f"current_best_test_accuracies {cta}")
##################################################
- # If all the models are good enough, generate new quizzes and
- # re-compute the test errors
for model in models:
if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
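+ # snapshot the weights and accuracy of any model that has reached
+ # the threshold for making c_quizzes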
model.best_dict = copy.deepcopy(model.state_dict())
model.best_test_accuracy = model.test_accuracy
- model.test_accuracy = 0.0
# we restart
if total_time_generating_c_quizzes == 0:
# Select, improve, and eval the worst model(s)
if total_time_training_models <= total_time_generating_c_quizzes:
- ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ ranked_models = sorted(
+ models,
+ # This ugly recipe picks the worst models if some are below
+ # args.accuracy_to_make_c_quizzes, or random ones if they
+ # are all above
+ key=lambda m: float(
+ m.test_accuracy
+ if m.test_accuracy < args.accuracy_to_make_c_quizzes
+ else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
+ ),
+ )
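+ # keep the len(gpus) weakest models to improve and evaluate next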
weakest_models = ranked_models[: len(gpus)]
######################################################################
- def data_input(self, nb_samples, c_quiz_bags):
+ def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1):
if len(c_quiz_bags) > 0:
c_quizzes = torch.cat(c_quiz_bags, dim=0)
+ if c_quiz_multiplier > 1:
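+ # repeat the c_quizzes up to c_quiz_multiplier times, without letting
+ # them grow beyond half of the requested samples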
+ n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+ body = c_quizzes.repeat(n, 1)
+ if n < c_quiz_multiplier:
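+ # the cap was hit: complete with a random subset so the c_quizzes
+ # fill exactly nb_samples // 2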
+ tail = c_quizzes[
+ torch.randperm(c_quizzes.size(0))[
+ : nb_samples // 2 - body.size(0)
+ ]
+ ]
+ c_quizzes = torch.cat([body, tail], dim=0)
+ else:
+ c_quizzes = body
+
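+ # whatever the multiplier, keep at most nb_samples // 2 c_quizzes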
if c_quizzes.size(0) > nb_samples // 2:
i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
c_quizzes = c_quizzes[i]