nb,
data_structures,
local_device,
+ c_quizzes=None,
desc=None,
batch_size=args.batch_size,
):
+ c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
+
full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
- nb, data_structures=data_structures
+ nb, c_quiz_bags, data_structures=data_structures
)
src = zip(
######################################################################
-def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_device):
+def one_ae_epoch(
+ model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
+):
model.train().to(local_device)
nb_train_samples, acc_train_loss = 0, 0.0
args.nb_train_samples,
data_structures,
local_device,
+ c_quizzes,
"training",
):
input = input.to(local_device)
def generate_ae_c_quizzes(models, local_device=main_device):
criteria = [
c_quiz_criterion_one_good_one_bad,
- c_quiz_criterion_diff,
- c_quiz_criterion_two_certains,
- c_quiz_criterion_some,
+ # c_quiz_criterion_diff,
+ # c_quiz_criterion_two_certains,
+ # c_quiz_criterion_some,
]
for m in models:
quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
)
- duration_max = 3600
+ duration_max = 4 * 3600
- wanted_nb = 512
+ wanted_nb = 10000
+ nb_to_save = 128
with torch.autograd.no_grad():
records = [[] for _ in criteria]
)
for n, u in enumerate(records):
- quizzes = torch.cat(u, dim=0)[:wanted_nb]
+ quizzes = torch.cat(u, dim=0)[:nb_to_save]
filename = f"culture_c_{n_epoch:04d}_{n:02d}.png"
# result, predicted_parts, correct_parts = bag_to_tensors(record)
# predicted_parts=predicted_parts,
# correct_parts=correct_parts,
comments=comments,
- nrow=8,
)
        log_string(f"wrote {filename}")
+ a = [torch.cat(u, dim=0) for u in records]
+
+ return torch.cat(a, dim=0).unique(dim=0)
+
######################################################################
######################################################################
+last_n_epoch_c_quizzes = 0
+
+c_quizzes = None
+
for n_epoch in range(current_epoch, args.nb_epochs):
start_time = time.perf_counter()
# one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
# exit(0)
- if min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes:
- generate_ae_c_quizzes(models, local_device=main_device)
+ if (
+ min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+ and n_epoch >= last_n_epoch_c_quizzes + 10
+ ):
+ last_n_epoch_c_quizzes = n_epoch
+ c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
+
+ if c_quizzes is None:
+ log_string("no_c_quiz")
+ else:
+ log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
t = threading.Thread(
target=one_ae_epoch,
daemon=True,
- args=(model, models, quiz_machine, n_epoch, gpu),
+ args=(model, models, quiz_machine, n_epoch, c_quizzes, gpu),
)
threads.append(t)