######################################################################
def add_hints_(imt_set):
    # In-place variant (trailing underscore, torch naming convention) of
    # add_hints: reveal args.nb_hints masked positions per sample as "hints".
    # imt_set packs (input, masks, targets) along dim 0; input and masks are
    # mutated in place, targets is only read.
    input, masks, targets = imt_set
    # Uniform noise shifted down by the mask: assuming masks is binary
    # (0/1) — TODO confirm against the callers that build it with one_hot —
    # masked positions fall in [-1, 0) and therefore sort first, so hints
    # are drawn only among masked positions (when enough of them exist).
    h = torch.rand(masks.size(), device=masks.device) - masks
    # Per-row threshold at the (nb_hints)-th smallest value: exactly
    # args.nb_hints entries satisfy h < t (ties are a.s. impossible with
    # continuous torch.rand values).
    t = h.sort(dim=1).values[:, args.nb_hints, None]
    mask_hints = (h < t).long()
    # Hinted positions are removed from the mask ...
    masks[...] = (1 - mask_hints) * masks
    # ... and their ground-truth values are revealed in the input.
    input[...] = (1 - mask_hints) * input + mask_hints * targets
+
+
def add_hints(masks, fraction_with_hints):
    """Return masks with args.nb_hints positions per row cleared as hints.

    Positions are chosen uniformly among the masked (non-zero) entries of
    each row.  NOTE(review): fraction_with_hints now acts only as an on/off
    gate — any positive value gives every row hints; confirm this is the
    intended semantics.  When it is <= 0, masks is returned unchanged.
    """
    if fraction_with_hints <= 0:
        return masks
    # Shift uniform noise down where the mask is set so masked positions
    # sort first; the nb_hints smallest values select the hint positions.
    noise = torch.rand(masks.size(), device=masks.device) - masks
    cutoff = noise.sort(dim=1).values[:, args.nb_hints, None]
    hint_mask = (noise < cutoff).long()
    return masks * (1 - hint_mask)
# IMT for input / masks / target
def batch_for_prediction_imt(input):
    """Build an IMT (input / masks / targets) batch for prediction.

    For each sample, one of the four contiguous quarters of the sequence
    is picked uniformly at random and erased from the input; the model is
    expected to reconstruct it.

    Args:
        input: tensor of shape (nb, L), with L divisible by 4 (the four
            quarters are laid out contiguously along dim 1).

    Returns:
        Tensor of shape (nb, 3, L) stacking, per sample:
        [masked input, binary mask of the erased quarter, targets].

    Raises:
        ValueError: if L is not divisible by 4 (previously surfaced as an
            opaque RuntimeError from view()).
    """
    nb = input.size(0)
    if input.size(1) % 4 != 0:
        raise ValueError(
            f"input width {input.size(1)} must be divisible by 4"
        )
    masks = input.new_zeros(input.size())
    # One-hot choice of a quarter per sample, broadcast over its span.
    u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
    masks.view(nb, 4, -1)[...] = u[:, :, None]
    targets = input
    # Zero out the selected quarter in the model input.
    input = (1 - masks) * targets
    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
-def predict(model, imt_set, local_device=main_device, desc="predict"):
+def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
model.eval().to(local_device)
record = []
return torch.cat(record)
def predict_full(model, input, local_device=main_device):
    # Predict all four quarters of every sample: each sample is replicated
    # four times, copy k gets quarter k masked, and the four masked
    # predictions are summed back into a single output row per sample.
    # Assumes input.size(1) is divisible by 4 — TODO confirm at call sites.
    input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
    nb = input.size(0)
    masks = input.new_zeros(input.size())
    # Row index modulo 4 selects which quarter copy k masks.
    u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
    masks.view(nb, 4, -1)[...] = u[:, :, None]
    targets = input
    # Hide the masked quarter from the model input.
    input = (1 - masks) * targets
    imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
    result = ae_predict(model, imt_set, local_device=local_device, desc=None)
    # Keep only the predicted (masked) positions of each copy, then merge
    # the four copies of a sample: the four masks are disjoint, so the sum
    # concatenates the quarter predictions rather than accumulating them.
    result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
    return result
######################################################################
-def batch_generation_imt(input):
+def batch_for_generation_imt(input):
nb = input.size(0)
probs_iterations = 0.1 ** torch.linspace(
0, 1, args.diffusion_nb_iterations, device=input.device
return y
-def generate(model, nb, local_device=main_device, desc="generate"):
+def ae_generate(model, nb, local_device=main_device, desc="generate"):
model.eval().to(local_device)
all_input = quiz_machine.pure_noise(nb, local_device)
imt_set = torch.cat(
[
- batch_prediction_imt(q1, fraction_with_hints=0.5),
- batch_generation_imt(q2),
+ batch_for_prediction_imt(q1),
+ batch_for_generation_imt(q2),
]
)
args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
)
- # Save some images of the prediction results (one grid at random)
+ # Save some images of the prediction results
quizzes = quiz_machine.quiz_set(
args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
)
- imt_set = batch_prediction_imt(quizzes.to(local_device))
- result = predict(model, imt_set, local_device=local_device).to("cpu")
+ imt_set = batch_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
correct = (quizzes == result).min(dim=1).values.long()
# Save some images of the ex nihilo generation of the four grids
- result = generate(model, 150, local_device=local_device).to("cpu")
+ result = ae_generate(model, 150, local_device=local_device).to("cpu")
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
f"culture_generation_{n_epoch}_{model.id}.png",
######################################################################
-def generate_c_quizzes(models, nb, local_device=main_device):
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
record = []
nb_validated = 0
start_time = time.perf_counter()
last_log = -1
- while nb_validated < nb:
+ while nb_validated < nb_to_generate:
# Generate new quizzes
model = models[torch.randint(len(models), (1,)).item()]
model = copy.deepcopy(model).to(local_device).eval()
generator_id = model.id
- c_quizzes = generate(
+ c_quizzes = ae_generate(
model=model,
nb=args.physical_batch_size,
local_device=local_device,
if last_log < 0 or duration > last_log + 10:
last_log = duration
if nb_validated > 0:
- if nb_validated < nb:
- d = (nb - nb_validated) * duration / nb_validated
+ if nb_validated < nb_to_generate:
+ d = (nb_to_generate - nb_validated) * duration / nb_validated
e = (
datetime.datetime.now() + datetime.timedelta(seconds=d)
).strftime("%a %H:%M")
duration = time.perf_counter() - start_time
- log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h")
+ log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
return torch.cat(record).to("cpu")
mask_generate = quiz_machine.make_quiz_mask(
quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
)
- result = generate(
+ result = ae_generate(
model,
(1 - mask_generate) * quizzes,
mask_generate,
######################################################################
-last_n_epoch_c_quizzes = 0
-
c_quizzes = None
time_c_quizzes = 0
if c_quizzes is None:
save_models(models, "naive")
- last_n_epoch_c_quizzes = n_epoch
nb_gpus = len(gpus)
nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus