mask_diffusion_noise = mask_diffusion_noise.long()
input[...] = (
- 1 - mask_generate
- ) * input + mask_generate * mask_diffusion_noise * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
+ mask_generate
+ * mask_diffusion_noise
+ * torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
+ + (1 - mask_generate * mask_diffusion_noise) * input
)
+
else:
model.eval()
for it in range(torch.randint(5, (1,)).item()):
torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
)
).x
+
+ # for filename, quizzes in [
+ # ("targets.png", targets),
+ # ("input.png", input),
+ # ("mask_generate.png", mask_generate),
+ # ("mask_loss.png", mask_loss),
+ # ]:
+ # quiz_machine.problem.save_quizzes_as_image(
+ # args.result_dir,
+ # filename,
+ # quizzes=quizzes,
+ # )
+ # time.sleep(10000)
+
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)