    return (-loss).exp()
-nb_diffusion_iterations = 25
+def model_ae_argmax_nb_disagreements(model, input):
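+    # For each quiz of `input`, mask each of the four parts in turn, let the
+    # model fill it in deterministically (argmax), and count the tokens that
+    # differ from the quiz's own content, summed over the four parts;
+    # returns one disagreement count per quiz.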
+    record = []
+
+    for q in input.split(args.batch_size):
+        nb_disagreements = 0
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            targets, logits = targets_and_prediction(
+                model, q, mask_generate, prompt_noise=args.prompt_noise
+            )
+
+            predicted = logits.argmax(dim=-1)
+
+            nb_disagreements = nb_disagreements + (
+                mask_generate * predicted != mask_generate * targets
+            ).long().sum(dim=1)
+
+        record.append(nb_disagreements)
+
+    return torch.cat(record, dim=0)
def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
def generate_ae_c_quizzes(models, nb, local_device=main_device):
-    criteria = [
-        c_quiz_criterion_few_good_one_bad,
-        # c_quiz_criterion_only_one,
-        # c_quiz_criterion_one_good_one_bad,
-        # c_quiz_criterion_one_good_no_very_bad,
-        # c_quiz_criterion_diff,
-        # c_quiz_criterion_diff2,
-        # c_quiz_criterion_two_good,
-        # c_quiz_criterion_some,
-    ]
-
    # To be thread-safe we must make copies
-    models = [copy.deepcopy(model).to(local_device) for model in models]
-
    quad_order = ("A", "f_A", "B", "f_B")
    template = quiz_machine.problem.create_empty_quizzes(args.batch_size, quad_order)
    mask_generate = quiz_machine.make_quiz_mask(
        quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
    )
-    duration_max = 4 * 3600
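+    # Threads work on their own copy of each model, moved to the inference
+    # device and switched to eval mode.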
+    def copy_for_inference(model):
+        return copy.deepcopy(model).to(local_device).eval()
    wanted_nb = nb
    nb_to_save = 256
    with torch.autograd.no_grad():
-        records = [[] for _ in criteria]
+        records = []
        last_log = -1
        start_time = time.perf_counter()
-        while (
-            time.perf_counter() < start_time + duration_max
-            and min([bag_len(bag) for bag in records]) < wanted_nb
-        ):
-            model = models[torch.randint(len(models), (1,)).item()]
+        while bag_len(records) < wanted_nb:
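+            # Pick a model at random to propose a batch of candidate quizzes.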
+            model = copy_for_inference(
+                models[torch.randint(len(models), (1,)).item()]
+            )
+
            c_quizzes = ae_generate(model, template, mask_generate)
            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
            c_quizzes = c_quizzes[to_keep]

            if c_quizzes.size(0) > 0:
-                probas = torch.cat(
-                    [
-                        model_ae_proba_solutions(model, c_quizzes)[:, None]
-                        for model in models
-                    ],
-                    dim=1,
-                )
+                # p = [
+                #     model_ae_proba_solutions(model, c_quizzes)[:, None]
+                #     for model in models
+                # ]
+
+                # probas = torch.cat(p, dim=1)
+                # to_keep = c_quiz_criterion_two_good(probas)
-                for c, r in zip(criteria, records):
-                    q = c_quizzes[c(probas)]
-                    if q.size(0) > 0:
-                        r.append(q)
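+                # Instead, score each candidate by how many tokens of every
+                # model's argmax solution disagree with it.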
+                nb_disagreements = []
+                for m in models:
+                    m = copy_for_inference(m)
+                    nb_disagreements.append(
+                        model_ae_argmax_nb_disagreements(m, c_quizzes).long()[:, None]
+                    )
+                nb_disagreements = torch.cat(nb_disagreements, dim=1)
+
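+                # Keep a quiz iff at least two models reproduce it exactly
+                # (the second-smallest disagreement count is 0) while the
+                # worst model gets more than 3 tokens wrong.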
+                v = nb_disagreements.sort(dim=1).values
+                to_keep = (v[:, 1] == 0) & (v[:, -1] > 3)
+
+                q = c_quizzes[to_keep]
+
+                if q.size(0) > 0:
+                    records.append(q)
            duration = time.perf_counter() - start_time
-            nb_generated = min([bag_len(bag) for bag in records])
+            nb_generated = bag_len(records)
-            if last_log < 0 or duration > last_log + 60:
+            if last_log < 0 or duration > last_log + 5:
                last_log = duration
                if nb_generated > 0:
                    if nb_generated < wanted_nb:
                        d = (wanted_nb - nb_generated) * duration / nb_generated
                        e = (
                            datetime.datetime.now() + datetime.timedelta(seconds=d)
                        ).strftime("%a %H:%M")
                    else:
                        e = "now!"
                else:
                    e = "???"
-                bl = [bag_len(bag) for bag in records]
                log_string(
-                    f"bag_len {bl} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
+                    f"nb_generated {bag_len(records)} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
                )
        duration = time.perf_counter() - start_time
        log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
-        for n, u in enumerate(records):
-            quizzes = torch.cat(u, dim=0)[:nb_to_save]
-            filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png"
+        c_quizzes = torch.cat(records, dim=0).unique(dim=0)
-        # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
+        subset_c_quizzes = c_quizzes[:nb_to_save]
-            l = [model_ae_proba_solutions(model, quizzes) for model in models]
-            probas = torch.cat([x[:, None] for x in l], dim=1)
-            comments = []
+        filename = f"culture_c_quiz_{n_epoch:04d}.png"
-            for l in probas:
-                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+        # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
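+        # Annotate each saved quiz with every model's probability of the
+        # quiz's own solution.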
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=quizzes,
- # predicted_parts=predicted_parts,
- # correct_parts=correct_parts,
- comments=comments,
- delta=True,
- nrow=8,
- )
+        l = [
+            model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes)
+            for model in models
+        ]
+        probas = torch.cat([x[:, None] for x in l], dim=1)
+        comments = []
-        log_string(f"wrote {filename}")
+        for l in probas:
+            comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-        a = [torch.cat(u, dim=0) for u in records]
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            filename,
+            quizzes=subset_c_quizzes,
+            # predicted_parts=predicted_parts,
+            # correct_parts=correct_parts,
+            comments=comments,
+            delta=True,
+            nrow=8,
+        )
+
+        log_string(f"wrote {filename}")
-        return torch.cat(a, dim=0).unique(dim=0)
+        return c_quizzes
def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):