-if args.learning_rate_schedule == "cos":
- learning_rate_schedule = {}
- for n_epoch in range(args.nb_epochs):
- u = n_epoch / args.nb_epochs * math.pi
- learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
-else:
- u = {
- int(k): float(v)
- for k, v in [
- tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+ if args.dirty_debug:
+ nb_correct = torch.randint(
+ len(models) + 1, nb_correct.size(), device=new_c_quizzes.device
+ )
+
+ for n in range(nb_correct.max() + 1):
+ recorded[n].append(new_c_quizzes[nb_correct == n].clone())
+
+ log_string(
+ f"keep c_quizzes {nb_validated()*100/nb_generated():.02f}% kept total {nb_validated()} / {nb_to_create}"
+ )
+
+ # concatenate and shuffle
+ for n in recorded.keys():
+ if len(recorded[n]) > 0:
+ q = torch.cat(recorded[n], dim=0)
+ q = q[torch.randperm(q.size(0), device=q.device)]
+ recorded[n] = q
+ else:
+ del recorded[n]
+
+ new_c_quizzes = torch.cat(
+ [recorded[n] for n in range(args.min_to_validate, args.max_to_validate + 1)],
+ dim=0,
+ )
+
+ new_c_quizzes = new_c_quizzes[
+ torch.randperm(new_c_quizzes.size(0), device=new_c_quizzes.device)[
+ : nb_for_train + nb_for_test