parser.add_argument("--inference_batch_size", type=int, default=25)
-parser.add_argument("--nb_train_samples", type=int, default=40000)
+parser.add_argument("--nb_train_samples", type=int, default=25000)
parser.add_argument("--nb_test_samples", type=int, default=1000)
parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
-parser.add_argument("--c_quiz_multiplier", type=int, default=4)
+parser.add_argument("--c_quiz_multiplier", type=int, default=10)
parser.add_argument("--learning_rate", type=float, default=5e-4)
mask_generate = quiz_machine.make_quiz_mask(
quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- targets, logits = targets_and_prediction(model, q, mask_generate)
+ targets, logits = targets_and_prediction(
+ model, q, mask_generate, prompt_noise=args.prompt_noise
+ )
loss_per_token = F.cross_entropy(
logits.transpose(1, 2), targets, reduction="none"
)
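F.cross_entropy with reduction="none" expects class logits in dimension 1,
hence the transpose from (batch, tokens, classes) to (batch, classes, tokens).
A minimal sketch of how such a per-token loss is typically reduced over the
masked positions only; the masked reduction is an assumption, not code from
this patch:

    import torch
    import torch.nn.functional as F

    N, T, C = 2, 5, 7
    logits = torch.randn(N, T, C)            # (batch, tokens, classes)
    targets = torch.randint(0, C, (N, T))    # (batch, tokens)
    mask_generate = torch.zeros(N, T)
    mask_generate[:, 2:] = 1                 # hypothetical positions to generate

    loss_per_token = F.cross_entropy(
        logits.transpose(1, 2), targets, reduction="none"
    )                                        # (batch, tokens)
    loss = (loss_per_token * mask_generate).sum() / mask_generate.sum()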
log_string(f"wrote {filename}")
-def generate_ae_c_quizzes(models, local_device=main_device):
+def generate_ae_c_quizzes(models, nb, local_device=main_device):
criteria = [
# c_quiz_criterion_only_one,
c_quiz_criterion_one_good_one_bad,
# c_quiz_criterion_some,
]
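The criteria are predicates used to filter candidate c-quizzes; their actual
signature is not visible in this hunk. A purely hypothetical sketch of what
the name suggests, keeping quizzes that at least one model solves and at
least one model fails:

    def c_quiz_criterion_one_good_one_bad(nb_correct, nb_models):
        # nb_correct: per-quiz count of models that solved it (assumed).
        # Keep quizzes with at least one "good" and one "bad" model.
        return (nb_correct >= 1) & (nb_correct < nb_models)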
- for m in models:
- m.eval().to(local_device)
+ # To be thread-safe we must make copies, kept in eval mode for generation
+ models = [copy.deepcopy(model).eval().to(local_device) for model in models]
quad_order = ("A", "f_A", "B", "f_B")
duration_max = 4 * 3600
- wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
+ wanted_nb = nb
nb_to_save = 256
with torch.autograd.no_grad():
return torch.cat(a, dim=0).unique(dim=0)
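The final unique(dim=0) deduplicates (and lexicographically sorts) the
generated quizzes row-wise, treating each row of token ids as one quiz:

    import torch

    a = [torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 2], [5, 6]])]
    print(torch.cat(a, dim=0).unique(dim=0))
    # tensor([[1, 2],
    #         [3, 4],
    #         [5, 6]])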
+def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
+ record.append(generate_ae_c_quizzes(models, nb, local_device))
+
+
######################################################################
current_epoch = 0
save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after")
last_n_epoch_c_quizzes = n_epoch
+ nb_c_quizzes_to_generate = args.nb_train_samples // args.c_quiz_multiplier
+
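With the new defaults this requests 25000 // 10 = 2500 fresh c-quizzes, and
each generator thread below receives that full count.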
+ # --------------------------------------------------------------------
+
+ records, threads = [], []
+
start_time = time.perf_counter()
- c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
+
+ for gpu in gpus:
+ t = threading.Thread(
+ target=thread_generate_ae_c_quizzes,
+ daemon=True,
+ args=(models, nb_c_quizzes_to_generate, records, gpu),
+ )
+
+ # Perturb the shared global RNG so the threads do not
+ # produce identical sequences
+ log_string(f"dummy {torch.rand(1)}")
+ threads.append(t)
+ t.start()
+
+ for t in threads:
+ t.join()
+
time_c_quizzes = time.perf_counter() - start_time
+
+ c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0)
+
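Since threading.Thread provides no return value, each worker appends its
result to the shared records list (list.append is atomic under the GIL) and
the main thread concatenates after joining. A self-contained sketch of this
fork-join pattern, with "cpu" stand-ins for the per-GPU devices:

    import threading
    import torch

    def worker(record, device):
        # Stand-in for thread_generate_ae_c_quizzes: compute on the given
        # device, then publish the result through the shared list.
        record.append(torch.ones(2, 3, device=device))

    records, threads = [], []
    for device in ["cpu", "cpu"]:
        t = threading.Thread(target=worker, daemon=True, args=(records, device))
        threads.append(t)
        t.start()

    for t in threads:
        t.join()

    result = torch.cat([r.to("cpu") for r in records], dim=0)  # shape (4, 3)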
+ # --------------------------------------------------------------------
+
+ log_string(f"generated_c_quizzes {c_quizzes.size()=}")
+
time_train = 0
for model in models:
model.test_accuracy = 0
else:
log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
+ # --------------------------------------------------------------------
+
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]