def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples, data_structures=data_structures
+ nb, data_structures=data_structures
)
src = zip(
targets = input
- mask_noise = (mask_generate != 0) & (
+ mask_diffusion_noise = (mask_generate == 1) & (
torch.rand(mask_generate.size(), device=mask_generate.device)
<= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
)
- mask_noise = mask_noise.long()
+ mask_diffusion_noise = mask_diffusion_noise.long()
- input = (1 - mask_noise) * input + mask_noise * torch.randint(
+ input = (
+ 1 - mask_diffusion_noise
+ ) * input + mask_diffusion_noise * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
):
targets = input
- mask_noise = (mask_generate != 0) & (
+ mask_diffusion_noise = (mask_generate == 1) & (
torch.rand(mask_generate.size(), device=mask_generate.device)
<= torch.rand(
(mask_generate.size(0), 1), device=mask_generate.device
)
)
- mask_noise = mask_noise.long()
+ mask_diffusion_noise = mask_diffusion_noise.long()
- input = (1 - mask_noise) * input + mask_noise * torch.randint(
+ input = (
+ 1 - mask_diffusion_noise
+ ) * input + mask_diffusion_noise * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
+ # -------------------------------------------
+ # Test generation
+
input, mask_generate, mask_loss = next(
ae_batches(quiz_machine, 128, data_structures, local_device)
)
targets = input
input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
-
pred_result = None
-
- mask_noise = (mask_generate != 0) & (
- torch.rand(mask_generate.size(), device=mask_generate.device)
- <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
- )
-
- mask_noise = mask_noise.long()
-
- result = (1 - mask_noise) * input + mask_noise * torch.randint(
+ result = (1 - mask_generate) * input + mask_generate * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
logits = model(mygpt.BracketedSequence(result)).x
dist = torch.distributions.categorical.Categorical(logits=logits)
pred_result = result.clone()
- result[i] = (1 - mask_generate) * input + (
+ result[i] = (1 - mask_generate[i]) * input + (
mask_generate * dist.sample()[i]
)
changed = (pred_result == result).long().min(dim=1).values == 0