return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
+def c_quiz_criterion_some(probas):
+ return ((probas >= 0.8).long().sum(dim=1) >= 1) & (
+ (probas <= 0.2).long().sum(dim=1) >= 1
+ )
+
+
def generate_ae_c_quizzes(models, local_device=main_device):
criteria = [
c_quiz_criterion_one_good_one_bad,
c_quiz_criterion_diff,
c_quiz_criterion_two_certains,
+ c_quiz_criterion_some,
]
for m in models:
quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
)
- duration_max = 600 # 3 * 3600
+ duration_max = 3600
+
+ wanted_nb = 512
with torch.autograd.no_grad():
records = [[] for _ in criteria]
while (
time.perf_counter() < start_time + duration_max
- and min([bag_len(bag) for bag in records]) < 128
+ and min([bag_len(bag) for bag in records]) < wanted_nb
):
bl = [bag_len(bag) for bag in records]
log_string(f"bag_len {bl}")
model = models[torch.randint(len(models), (1,)).item()]
result = ae_generate(model, template, mask_generate, noise_proba)
- probas = torch.cat(
- [model_ae_proba_solutions(model, result)[:, None] for model in models],
- dim=1,
- )
+ to_keep = quiz_machine.problem.trivial(result) == False
+ result = result[to_keep]
- for c, r in zip(criteria, records):
- q = result[c(probas)]
- if q.size(0) > 0:
- r.append(q)
+ if result.size(0) > 0:
+ probas = torch.cat(
+ [
+ model_ae_proba_solutions(model, result)[:, None]
+ for model in models
+ ],
+ dim=1,
+ )
- for n, u in enumerate(records):
- quizzes = torch.cat(u, dim=0)[:128]
- filename = f"culture_{n_epoch:04d}_{n:02d}.png"
+ for c, r in zip(criteria, records):
+ q = result[c(probas)]
+ if q.size(0) > 0:
+ r.append(q)
- # result, predicted_parts, correct_parts = bag_to_tensors(record)
+ duration = time.perf_counter() - start_time
- # l = [model_ae_proba_solutions(model, result) for model in models]
- # probas = torch.cat([x[:, None] for x in l], dim=1)
- # comments = []
+ log_string(
+ f"generate_c_quizz_generation_speed {int(3600 * wanted_nb / duration)}/h"
+ )
- # for l in probas:
- # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+ for n, u in enumerate(records):
+ quizzes = torch.cat(u, dim=0)[:wanted_nb]
+ filename = f"culture_c_{n_epoch:04d}_{n:02d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result,
- # predicted_parts=predicted_parts,
- # correct_parts=correct_parts,
- # comments=comments,
- )
+ # result, predicted_parts, correct_parts = bag_to_tensors(record)
- log_string(f"wrote {filename}")
+ l = [model_ae_proba_solutions(model, quizzes) for model in models]
+ probas = torch.cat([x[:, None] for x in l], dim=1)
+ comments = []
+
+ for l in probas:
+ comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=quizzes,
+ # predicted_parts=predicted_parts,
+ # correct_parts=correct_parts,
+ comments=comments,
+ nrow=8,
+ )
+
+ log_string(f"wrote {filename}")
######################################################################