From: François Fleuret Date: Thu, 22 Aug 2024 20:50:43 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=dad7b896973488739c4b1c833e7c0991e79e5bec;p=culture.git Update. --- diff --git a/main.py b/main.py index f28c10a..2a35209 100755 --- a/main.py +++ b/main.py @@ -872,12 +872,24 @@ def test_ae(local_device=main_device): targets = input - input = (mask_generate == 0).long() * input + ( - 1 - (mask_generate == 0).long() - ) * torch.randint( + mask_noise = (mask_generate != 0) & ( + torch.rand(mask_generate.size(), device=mask_generate.device) + <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) + ) + + mask_noise = mask_noise.long() + + input = (1 - mask_noise) * input + mask_noise * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) + L = input.size(1) // 4 + + input[:, 0 * L] = targets[:, 0 * L] + input[:, 1 * L] = targets[:, 1 * L] + input[:, 2 * L] = targets[:, 2 * L] + input[:, 3 * L] = targets[:, 3 * L] + output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) acc_train_loss += loss.item() * input.size(0) @@ -921,12 +933,26 @@ def test_ae(local_device=main_device): targets = input - input = (mask_generate == 0).long() * input + ( - 1 - (mask_generate == 0).long() - ) * torch.randint( + mask_noise = (mask_generate != 0) & ( + torch.rand(mask_generate.size(), device=mask_generate.device) + <= torch.rand( + (mask_generate.size(0), 1), device=mask_generate.device + ) + ) + + mask_noise = mask_noise.long() + + input = (1 - mask_noise) * input + mask_noise * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) + L = input.size(1) // 4 + + input[:, 0 * L] = targets[:, 0 * L] + input[:, 1 * L] = targets[:, 1 * L] + input[:, 2 * L] = targets[:, 2 * L] + input[:, 3 * L] = targets[:, 3 * L] + output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) acc_test_loss += loss.item() * input.size(0) @@ -943,18 +969,28 @@ def test_ae(local_device=main_device): pred_result = None frozzen = None - result = (mask_generate == 0).long() * input + ( - 1 - (mask_generate == 0).long() - ) * torch.randint( + mask_noise = (mask_generate != 0) & ( + torch.rand(mask_generate.size(), device=mask_generate.device) + <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device) + ) + + mask_noise = mask_noise.long() + + result = (1 - mask_noise) * input + mask_noise * torch.randint( quiz_machine.problem.nb_colors, input.size(), device=input.device ) + L = input.size(1) // 4 + + result[:, 0 * L] = input[:, 0 * L] + result[:, 1 * L] = input[:, 1 * L] + result[:, 2 * L] = input[:, 2 * L] + result[:, 3 * L] = input[:, 3 * L] + i = torch.full((result.size(0),), True, device=result.device) nb_it = 0 - L = input.size(1) // 4 - while True: logits = model(mygpt.BracketedSequence(result)).x dist = torch.distributions.categorical.Categorical(logits=logits)