)
-def mu_T_sampler(shape, device="cpu"):
- return torch.randint(quiz_machine.problem.nb_colors, shape, device=device)
-
-
diffuser = diffusion.Diffuser(
mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption
)
######################################################################
+
+def add_hints(masks, fraction_with_hints):
+ if fraction_with_hints > 0:
+ h = torch.rand(masks.size(), device=masks.device) * masks
+ mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
+ v = torch.rand(masks.size(0), device=masks.device)[:, None]
+ mask_hints = mask_hints * (v < fraction_with_hints).long()
+ return (1 - mask_hints) * masks
+ else:
+ return masks
+
+
# IMT for input / masks / target
-def IMT_batch_prediction(input, proba_hints=0.0):
+def batch_prediction_imt(input, fraction_with_hints=0.0):
nb = input.size(0)
masks = input.new_zeros(input.size())
u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
-
- if proba_hints > 0:
- h = torch.rand(input.size(), device=input.device) * masks
- mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
- v = torch.rand(nb, device=input.device)[:, None]
- mask_hints = mask_hints * (v < proba_hints).long()
- masks = (1 - mask_hints) * masks
-
+ masks = add_hints(masks, fraction_with_hints)
# noise = quiz_machine.problem.pure_noise(nb, input.device)
targets = input
input = (1 - masks) * targets # + masks * noise
return torch.cat(record)
+def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device):
+ boy_that_s_ugly = input.view(input.size(0), 4, -1)[:, :, 0].clone()
+ input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+ masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+ masks_with_hints = add_hints(masks, fraction_with_hints)
+ targets = input
+ input = (1 - masks_with_hints) * targets
+ imt_set = torch.cat(
+ [input[:, None], masks_with_hints[:, None], targets[:, None]], dim=1
+ )
+
+ result = predict(model, imt_set, local_device=local_device)
+ result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+ result.view(result.size(0), 4, -1)[:, :, 0] = boy_that_s_ugly
+
+ return result
+
+
######################################################################
-def IMT_batch_generation(input):
+def batch_generation_imt(input):
nb = input.size(0)
probs_iterations = 0.1 ** torch.linspace(
0, 1, args.diffusion_nb_iterations, device=input.device
def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
- if train:
- label = "train"
- model.train().to(local_device)
- optimizer_to(model.optimizer, local_device)
- else:
- label = "test"
- model.eval().to(local_device)
-
- nb_samples, acc_loss = 0, 0.0
-
quizzes = quiz_machine.quiz_set(
args.nb_train_samples if train else args.nb_test_samples,
c_quizzes,
q1, q2 = quizzes.to(local_device).chunk(2)
imt_set = torch.cat(
- [IMT_batch_prediction(q1, proba_hints=0.5), IMT_batch_generation(q2)]
+ [
+ batch_prediction_imt(q1, fraction_with_hints=0.5),
+ batch_generation_imt(q2),
+ ]
)
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+ if train:
+ label = "train"
+ model.train().to(local_device)
+ optimizer_to(model.optimizer, local_device)
+ else:
+ label = "test"
+ model.eval().to(local_device)
+
+ nb_samples, acc_loss = 0, 0.0
+
for imt in tqdm.tqdm(
imt_set.split(args.physical_batch_size),
dynamic_ncols=True,
one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
+ #!!!!!!!!!!!!!!!!!!!!!!!!!
+ quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(
+ local_device
+ )
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
+ )
+ result = predict_full(model, quizzes, local_device=local_device)
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
+ )
+ #!!!!!!!!!!!!!!!!!!!!!!!!!
+
# predict
quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
- imt_set = IMT_batch_prediction(quizzes.to(local_device))
+ imt_set = batch_prediction_imt(quizzes.to(local_device))
result = predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
######################################################################
-def quiz_validation(
+def quiz_validation_(
models,
c_quizzes,
local_device,