hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
1 - one_iteration_prediction
- ) * sample_x_t_minus_1_given_x_0_x_t(
- hat_x_0, x_t, max(1, args.nb_diffusion_iterations - it)
- )
+ ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
if hat_x_t_minus_1.equal(x_t):
# log_string(f"exit after {it+1} iterations")
return (-loss).exp()
-def model_ae_argmax_nb_disagreements(model, input):
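+ # count, for each quiz of input, how many masked tokens the model's argmax
+ # prediction gets wrong, accumulated over the four one-quadrant masks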
+def model_ae_argmax_nb_mistakes(model, input):
record = []
for x_0 in input.split(args.batch_size):
- nb_disagreements = 0
+ nb_mistakes = 0
for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
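+ # each one-hot quad masks a single quadrant of ("A", "f_A", "B", "f_B")
+ # for the model to predict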
mask_generate = quiz_machine.make_quiz_mask(
quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
predicted = logits.argmax(dim=-1)
- nb_disagreements = nb_disagreements + (
+ nb_mistakes = nb_mistakes + (
mask_generate * predicted != mask_generate * x_0
).long().sum(dim=1)
- record.append(nb_disagreements)
+ record.append(nb_mistakes)
return torch.cat(record, dim=0)
######################################################################
-def c_quiz_criterion_one_good_one_bad(probas):
- return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25)
-
-
-def c_quiz_criterion_one_good_no_very_bad(probas):
- return (
- (probas.max(dim=1).values >= 0.75)
- & (probas.min(dim=1).values <= 0.75)
- & (probas.min(dim=1).values >= 0.25)
- )
-
-
-def c_quiz_criterion_diff(probas):
- return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
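+ # a c-quiz is kept when at least half of the models solve it in every run
+ # and at least a fifth fail it clearly; returns the mask of quizzes to
+ # keep and, for each model, whether it got each quiz wrong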
+def quiz_validation(models, c_quizzes, local_device):
+ nb_have_to_be_correct = args.nb_models // 2
+ nb_have_to_be_wrong = args.nb_models // 5
+ nb_runs = 3
+ nb_mistakes_to_be_wrong = 5
-def c_quiz_criterion_diff2(probas):
- v = probas.sort(dim=1).values
- return (v[:, -2] - v[:, 0]) >= 0.5
+ record_wrong = []
+ nb_correct, nb_wrong = 0, 0
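+ # a model counts as correct on a quiz if it makes no mistake in any of
+ # the nb_runs attempts, and as wrong if it makes at least
+ # nb_mistakes_to_be_wrong mistakes in at least one attempt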
+ for i, model in enumerate(models):
+ assert i == model.id # a bit of paranoia
+ model = copy.deepcopy(model).to(local_device).eval()
+ correct, wrong = True, False
+ for _ in range(nb_runs):
+ n = model_ae_argmax_nb_mistakes(model, c_quizzes).long()
+ correct = correct & (n == 0)
+ wrong = wrong | (n >= nb_mistakes_to_be_wrong)
+ record_wrong.append(wrong[:, None])
+ nb_correct += correct.long()
+ nb_wrong += wrong.long()
-def c_quiz_criterion_few_good_one_bad(probas):
- v = probas.sort(dim=1).values
- return (v[:, 0] <= 0.25) & (v[:, -3] >= 0.5)
+ # print("nb_correct", nb_correct)
+ # print("nb_wrong", nb_wrong)
-def c_quiz_criterion_two_good(probas):
- return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2)
+ to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
+ wrong = torch.cat(record_wrong, dim=1)
-def c_quiz_criterion_some(probas):
- return ((probas >= 0.8).long().sum(dim=1) >= 1) & (
- (probas <= 0.2).long().sum(dim=1) >= 1
- )
+ return to_keep, wrong
def generate_ae_c_quizzes(models, nb, local_device=main_device):
c_quizzes = c_quizzes[to_keep]
if c_quizzes.size(0) > 0:
- # p = [
- # model_ae_proba_solutions(model, c_quizzes)[:, None]
- # for model in models
- # ]
-
- # probas = torch.cat(p, dim=1)
- # to_keep = c_quiz_criterion_two_good(probas)
-
- nb_disagreements = []
- for i, model in enumerate(models):
- assert i == model.id # a bit of paranoia
- model = copy_for_inference(model)
- nb_disagreements.append(
- model_ae_argmax_nb_disagreements(model, c_quizzes).long()[
- :, None
- ]
- )
- nb_disagreements = torch.cat(nb_disagreements, dim=1)
-
- v = nb_disagreements.sort(dim=1).values
- to_keep = (v[:, 2] == 0) & (v[:, -1] >= 4)
-
+ to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
q = c_quizzes[to_keep]
if q.size(0) > 0:
record_c_quizzes.append(q)
- a = (nb_disagreements == 0)[to_keep]
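+ # a model "agrees" with a kept c-quiz when it never got it wrong
+ # during validation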
+ a = (~record_wrong)[to_keep]
record_agreements.append(a)
nb_c_quizzes_per_model += a.long().sum(dim=0)
subset_c_quizzes = c_quizzes[:nb_to_save]
- # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# for model in models:
- # model = copy_for_inference(model)
- # prediction = model_ae_argmax_predictions(model, subset_c_quizzes)
- # filename = f"prediction_c_quiz_{n_epoch:04d}_{model.id}.png"
+ # for r in range(3):
+ # filename = f"culture_c_quiz_{n_epoch:04d}_prediction_{model.id}_{r}.png"
+ # p = model_ae_argmax_predictions(copy_for_inference(model), subset_c_quizzes)
# quiz_machine.problem.save_quizzes_as_image(
# args.result_dir,
# filename,
- # quizzes=prediction,
+ # quizzes=p,
+ # delta=True,
# nrow=8,
# )
# log_string(f"wrote {filename}")
- # exit(0)
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
filename = f"culture_c_quiz_{n_epoch:04d}.png"
- # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
-
l = [
model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes)
for model in models
]
args.result_dir,
filename,
quizzes=subset_c_quizzes,
- # predicted_parts=predicted_parts,
- # correct_parts=correct_parts,
comments=comments,
delta=True,
nrow=8,
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
c_quizzes = state["c_quizzes"]
- # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
- # total_time_training_models = state["total_time_training_models"]
- # common_c_quiz_bags = state["common_c_quiz_bags"]
except FileNotFoundError:
log_string(f"cannot find {filename}")
pass
state = {
"current_epoch": n_epoch,
"c_quizzes": c_quizzes,
- # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
- # "total_time_training_models": total_time_training_models,
- # "common_c_quiz_bags": common_c_quiz_bags,
}
filename = "state.pth"
# --------------------------------------------------------------------
- # run_ae_test(
- # model,
- # alien_quiz_machine,
- # n_epoch,
- # c_quizzes=None,
- # local_device=main_device,
- # prefix="alien",
- # )
-
- # exit(0)
-
- # one_ae_epoch(models[0], quiz_machine, n_epoch, None, main_device)
- # exit(0)
-
log_string(f"{time_train=} {time_c_quizzes=}")
if (
min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes
and time_train >= time_c_quizzes
):
- if c_quizzes is not None:
- save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after")
+ if c_quizzes is None:
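+ # no c-quizzes generated so far: save the models as trained on the
+ # original quizzes only, hence "naive"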
+ for model in models:
+ filename = f"ae_{model.id:03d}_naive.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ },
+ os.path.join(args.result_dir, filename),
+ )
+
+ log_string(f"wrote {filename}")
+
+ # --------------------------------------------------------------------
last_n_epoch_c_quizzes = n_epoch
nb_gpus = len(gpus)
c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0)
agreements = torch.cat([a.to(main_device) for _, a in records], dim=0)
- print(f"DEBUG {c_quizzes.size()=} {agreements.size()=}")
-
# --------------------------------------------------------------------
log_string(f"generated_c_quizzes {c_quizzes.size()=}")
for model in models:
model.test_accuracy = 0
- save_badness_statistics(n_epoch, models, c_quizzes, "before")
-
if c_quizzes is None:
log_string("no_c_quiz")
else:
threads = []
- # for model in models:
- # log_string(f"DEBUG {model.id} {sum([ p.sum() for p in model.parameters()]).item()}")
-
start_time = time.perf_counter()
for gpu, model in zip(gpus, weakest_models):
"state_dict": model.state_dict(),
"optimizer_state_dict": model.optimizer.state_dict(),
"test_accuracy": model.test_accuracy,
- # "gen_test_accuracy": model.gen_test_accuracy,
- # "gen_state_dict": model.gen_state_dict,
- # "train_c_quiz_bags": model.train_c_quiz_bags,
- # "test_c_quiz_bags": model.test_c_quiz_bags,
},
os.path.join(args.result_dir, filename),
)