    return torch.cat(record, dim=0)
+######################################################################
+
+
+def model_ae_argmax_predictions(model, input):
+    result = input.clone()
+
+    for r, q in zip(result.split(args.batch_size), input.split(args.batch_size)):
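+        # predict each of the four quadrants in turn, conditioning on
+        # the original values of the three others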
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            targets, logits = targets_and_prediction(
+                model, q, mask_generate, prompt_noise=args.prompt_noise
+            )
+
+            predicted = logits.argmax(dim=-1)
+
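+            # write the argmax prediction into the masked quadrant and
+            # keep r unchanged elsewhere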
+            r[...] = (1 - mask_generate) * r + mask_generate * predicted
+
+    return result
+
+
+######################################################################
+
+
def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
    noise = torch.randint(
        quiz_machine.problem.nb_colors, input.size(), device=input.device
def run_ae_test(
    model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
):
-    if prefix is not None:
+    if prefix is None:
+        prefix = ""
+    else:
        prefix = prefix + "_"
    with torch.autograd.no_grad():
    wanted_nb = nb
    nb_to_save = 256
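+    # per-model count of kept c_quizzes that the model solves; generation
+    # runs until the smallest of these counts reaches wanted_nb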
+    nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
    with torch.autograd.no_grad():
-        records = []
+        record_c_quizzes, record_agreements = [], []
        last_log = -1
        start_time = time.perf_counter()
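+        # generate until every model solves at least wanted_nb c_quizzes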
-        while bag_len(records) < wanted_nb:
+        while nb_c_quizzes_per_model.min() < wanted_nb:
            model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
            generator_id = model.id
            # to_keep = c_quiz_criterion_two_good(probas)
            nb_disagreements = []
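+            # for every model, collect the per-quiz number of disagreements
+            # between its argmax predictions and the candidate c_quizzes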
-            for model in models:
+            for i, model in enumerate(models):
+                assert i == model.id  # a bit of paranoia
                model = copy_for_inference(model)
                nb_disagreements.append(
                    model_ae_argmax_nb_disagreements(model, c_quizzes).long()[
            nb_disagreements = torch.cat(nb_disagreements, dim=1)
            v = nb_disagreements.sort(dim=1).values
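+            # keep a quiz if at least three models solve it perfectly and
+            # at least one makes four or more mistakes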
-            to_keep = (v[:, 1] == 0) & (v[:, -1] > 3)
+            to_keep = (v[:, 2] == 0) & (v[:, -1] >= 4)
            q = c_quizzes[to_keep]
            if q.size(0) > 0:
-                records.append(q)
+                record_c_quizzes.append(q)
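+                # one boolean row per kept quiz, flagging the models
+                # that solve it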
+                a = (nb_disagreements == 0)[to_keep]
+                record_agreements.append(a)
+                nb_c_quizzes_per_model += a.long().sum(dim=0)
            duration = time.perf_counter() - start_time
-            nb_generated = bag_len(records)
+            nb_generated = nb_c_quizzes_per_model.min().item()
            if last_log < 0 or duration > last_log + 5:
                last_log = duration
                e = "???"
                log_string(
- f"nb_generated {bag_len(records)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
+ f"nb_generated {bag_len(record_c_quizzes)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
)
        duration = time.perf_counter() - start_time
        log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
-    c_quizzes = torch.cat(records, dim=0).unique(dim=0)
+    c_quizzes = torch.cat(record_c_quizzes, dim=0)
+    agreements = torch.cat(record_agreements, dim=0)
    subset_c_quizzes = c_quizzes[:nb_to_save]
+    # uncomment to visualize each model's argmax predictions on the
+    # saved subset of c_quizzes
+    #
+    # for model in models:
+    #     model = copy_for_inference(model)
+    #     prediction = model_ae_argmax_predictions(model, subset_c_quizzes)
+    #     filename = f"prediction_c_quiz_{n_epoch:04d}_{model.id}.png"
+    #     quiz_machine.problem.save_quizzes_as_image(
+    #         args.result_dir,
+    #         filename,
+    #         quizzes=prediction,
+    #         nrow=8,
+    #     )
+    #     log_string(f"wrote {filename}")
+
filename = f"culture_c_quiz_{n_epoch:04d}.png"
# c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
log_string(f"wrote {filename}")
-    return c_quizzes
+    return c_quizzes, agreements
def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
    time_c_quizzes = int(time.perf_counter() - start_time)
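+    # records holds the (c_quizzes, agreements) pairs appended by the
+    # generation threads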
-    c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0)
+    c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0)
+    agreements = torch.cat([a.to(main_device) for _, a in records], dim=0)
+
+ print(f"DEBUG {c_quizzes.size()=} {agreements.size()=}")
    # --------------------------------------------------------------------
    for gpu, model in zip(gpus, weakest_models):
        log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+        if c_quizzes is None:
+            c_quizzes_for_this_model = None
+        else:
+            c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
        t = threading.Thread(
            target=one_ae_epoch,
            daemon=True,
-            args=(model, quiz_machine, n_epoch, c_quizzes, gpu),
+            args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
        )
        threads.append(t)