From aa89fd9d6a29bd1b317152d6e66c9c1d2bc6203d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 23 Aug 2024 07:23:50 +0200 Subject: [PATCH] Update. --- main.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 1999bac..0fe33f6 100755 --- a/main.py +++ b/main.py @@ -823,7 +823,7 @@ class MyAttentionVAE(nn.Module): def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None): full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples, data_structures=data_structures + nb, data_structures=data_structures ) src = zip( @@ -894,14 +894,16 @@ def test_ae(local_device=main_device): targets = input - mask_noise = (mask_generate != 0) & ( + mask_diffusion_noise = (mask_generate == 1) & ( torch.rand(mask_generate.size(), device=mask_generate.device) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) ) - mask_noise = mask_noise.long() + mask_diffusion_noise = mask_diffusion_noise.long() - input = (1 - mask_noise) * input + mask_noise * torch.randint( + input = ( + 1 - mask_diffusion_noise + ) * input + mask_diffusion_noise * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -935,16 +937,18 @@ def test_ae(local_device=main_device): ): targets = input - mask_noise = (mask_generate != 0) & ( + mask_diffusion_noise = (mask_generate == 1) & ( torch.rand(mask_generate.size(), device=mask_generate.device) <= torch.rand( (mask_generate.size(0), 1), device=mask_generate.device ) ) - mask_noise = mask_noise.long() + mask_diffusion_noise = mask_diffusion_noise.long() - input = (1 - mask_noise) * input + mask_noise * torch.randint( + input = ( + 1 - mask_diffusion_noise + ) * input + mask_diffusion_noise * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -955,6 +959,9 @@ def test_ae(local_device=main_device): log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}") + # ------------------------------------------- + # Test generation + input, mask_generate, mask_loss = next( ae_batches(quiz_machine, 128, data_structures, local_device) ) @@ -962,17 +969,8 @@ def test_ae(local_device=main_device): targets = input input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA - pred_result = None - - mask_noise = (mask_generate != 0) & ( - torch.rand(mask_generate.size(), device=mask_generate.device) - <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) - ) - - mask_noise = mask_noise.long() - - result = (1 - mask_noise) * input + mask_noise * torch.randint( + result = (1 - mask_generate) * input + mask_generate * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) @@ -984,7 +982,7 @@ def test_ae(local_device=main_device): logits = model(mygpt.BracketedSequence(result)).x dist = torch.distributions.categorical.Categorical(logits=logits) pred_result = result.clone() - result[i] = (1 - mask_generate) * input + ( + result[i] = (1 - mask_generate[i]) * input + ( mask_generate * dist.sample()[i] ) changed = (pred_result == result).long().min(dim=1).values == 0 -- 2.39.5