# Training step (diff hunk): corrupt part of the sequence with random tokens,
# then train the model to reconstruct the uncorrupted sequence.
# Keep a handle on the uncorrupted batch; it is the target of the loss below.
targets = input
# Removed version: every position with mask_generate != 0 was replaced by a
# random value, i.e. the whole generated region was always fully noised.
- input = (mask_generate == 0).long() * input + (
- 1 - (mask_generate == 0).long()
- ) * torch.randint(
# New version: a position is noised only if mask_generate != 0 AND its own
# uniform sample is <= a second uniform sample drawn once per row (shape
# (mask_generate.size(0), 1), broadcast along the second dimension).
# The per-row sample therefore acts as that row's noise rate.
+ mask_noise = (mask_generate != 0) & (
+ torch.rand(mask_generate.size(), device=mask_generate.device)
+ <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+ )
+
# Convert the boolean mask to integer 0/1 so it can be used as a multiplier.
+ mask_noise = mask_noise.long()
+
# Blend: keep the original value where mask_noise == 0, otherwise take a
# random integer in [0, nb_colors). This rebinds `input` to a new tensor, so
# `targets` still refers to the uncorrupted data.
+ input = (1 - mask_noise) * input + mask_noise * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
# L is a quarter of the second dimension; columns 0, L, 2L and 3L are restored
# from `targets`, so they are never noised.
# NOTE(review): presumably these are the leading marker tokens of four
# equal-length segments - confirm against the sequence layout.
+ L = input.size(1) // 4
+
+ input[:, 0 * L] = targets[:, 0 * L]
+ input[:, 1 * L] = targets[:, 1 * L]
+ input[:, 2 * L] = targets[:, 2 * L]
+ input[:, 3 * L] = targets[:, 3 * L]
+
# Forward pass on the corrupted batch; the loss compares against the
# uncorrupted `targets`. The transpose swaps dimensions 1 and 2 before the
# cross-entropy (assumes output is (batch, position, class) - TODO confirm).
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
# Accumulate the loss weighted by the number of rows in this batch.
acc_train_loss += loss.item() * input.size(0)
# Test step (diff hunk): corrupt part of the sequence with random tokens and
# measure how well the model reconstructs the uncorrupted sequence.
# Keep a handle on the uncorrupted batch; it is the target of the loss below.
targets = input
# Removed version: every position with mask_generate != 0 was replaced by a
# random value, i.e. the whole generated region was always fully noised.
- input = (mask_generate == 0).long() * input + (
- 1 - (mask_generate == 0).long()
- ) * torch.randint(
# New version: a position is noised only if mask_generate != 0 AND its own
# uniform sample is <= a second uniform sample drawn once per row (shape
# (mask_generate.size(0), 1), broadcast along the second dimension).
# The per-row sample therefore acts as that row's noise rate.
+ mask_noise = (mask_generate != 0) & (
+ torch.rand(mask_generate.size(), device=mask_generate.device)
+ <= torch.rand(
+ (mask_generate.size(0), 1), device=mask_generate.device
+ )
+ )
+
# Convert the boolean mask to integer 0/1 so it can be used as a multiplier.
+ mask_noise = mask_noise.long()
+
# Blend: keep the original value where mask_noise == 0, otherwise take a
# random integer in [0, nb_colors). This rebinds `input` to a new tensor, so
# `targets` still refers to the uncorrupted data.
+ input = (1 - mask_noise) * input + mask_noise * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
# L is a quarter of the second dimension; columns 0, L, 2L and 3L are restored
# from `targets`, so they are never noised.
# NOTE(review): presumably these are the leading marker tokens of four
# equal-length segments - confirm against the sequence layout.
+ L = input.size(1) // 4
+
+ input[:, 0 * L] = targets[:, 0 * L]
+ input[:, 1 * L] = targets[:, 1 * L]
+ input[:, 2 * L] = targets[:, 2 * L]
+ input[:, 3 * L] = targets[:, 3 * L]
+
# Forward pass on the corrupted batch; the loss compares against the
# uncorrupted `targets`. The transpose swaps dimensions 1 and 2 before the
# cross-entropy (assumes output is (batch, position, class) - TODO confirm).
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
# Accumulate the loss weighted by the number of rows in this batch.
acc_test_loss += loss.item() * input.size(0)
# Generation setup (diff hunk): build the initial `result` tensor that the
# `while True:` loop following this hunk feeds to the model.
# Both start unset; their use is outside the visible part of this hunk.
pred_result = None
frozzen = None
# Removed version: every position with mask_generate != 0 started as a random
# value, i.e. the generated region was fully noised at start.
- result = (mask_generate == 0).long() * input + (
- 1 - (mask_generate == 0).long()
- ) * torch.randint(
# New version: a position is noised only if mask_generate != 0 AND its own
# uniform sample is <= a second uniform sample drawn once per row (shape
# (mask_generate.size(0), 1), broadcast along the second dimension).
+ mask_noise = (mask_generate != 0) & (
+ torch.rand(mask_generate.size(), device=mask_generate.device)
+ <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+ )
+
# Convert the boolean mask to integer 0/1 so it can be used as a multiplier.
+ mask_noise = mask_noise.long()
+
# Blend into a new tensor: keep `input` where mask_noise == 0, otherwise take
# a random integer in [0, nb_colors). `input` itself is left unmodified.
+ result = (1 - mask_noise) * input + mask_noise * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
# L is a quarter of the second dimension; it is now computed here because the
# restore lines below use it. Columns 0, L, 2L and 3L of `result` are copied
# back from `input`, so they are never noised.
# NOTE(review): presumably these are the leading marker tokens of four
# equal-length segments - confirm against the sequence layout.
+ L = input.size(1) // 4
+
+ result[:, 0 * L] = input[:, 0 * L]
+ result[:, 1 * L] = input[:, 1 * L]
+ result[:, 2 * L] = input[:, 2 * L]
+ result[:, 3 * L] = input[:, 3 * L]
+
# One boolean per row of `result`, all initially True.
# NOTE(review): presumably flags the rows still being updated by the loop
# that follows - confirm against the loop body.
i = torch.full((result.size(0),), True, device=result.device)
# Counter initialised to zero before the loop.
nb_it = 0
# Removed here: the definition of L was moved above, ahead of its first use.
- L = input.size(1) // 4
-
while True:
logits = model(mygpt.BracketedSequence(result)).x
dist = torch.distributions.categorical.Categorical(logits=logits)