quiz_machine.problem.nb_colors, input.size(), device=input.device
)
- L = input.size(1) // 4
-
- input[:, 0 * L] = targets[:, 0 * L]
- input[:, 1 * L] = targets[:, 1 * L]
- input[:, 2 * L] = targets[:, 2 * L]
- input[:, 3 * L] = targets[:, 3 * L]
-
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
acc_train_loss += loss.item() * input.size(0)
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
- L = input.size(1) // 4
-
- input[:, 0 * L] = targets[:, 0 * L]
- input[:, 1 * L] = targets[:, 1 * L]
- input[:, 2 * L] = targets[:, 2 * L]
- input[:, 3 * L] = targets[:, 3 * L]
-
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
acc_test_loss += loss.item() * input.size(0)
targets = input
+ input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
+
pred_result = None
- frozzen = None
mask_noise = (mask_generate != 0) & (
torch.rand(mask_generate.size(), device=mask_generate.device)
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
- L = input.size(1) // 4
-
- result[:, 0 * L] = input[:, 0 * L]
- result[:, 1 * L] = input[:, 1 * L]
- result[:, 2 * L] = input[:, 2 * L]
- result[:, 3 * L] = input[:, 3 * L]
-
i = torch.full((result.size(0),), True, device=result.device)
nb_it = 0
logits = model(mygpt.BracketedSequence(result)).x
dist = torch.distributions.categorical.Categorical(logits=logits)
pred_result = result.clone()
- result[i] = dist.sample()[i]
- result[:, 0 * L] = input[:, 0 * L]
- result[:, 1 * L] = input[:, 1 * L]
- result[:, 2 * L] = input[:, 2 * L]
- result[:, 3 * L] = input[:, 3 * L]
+ result[i] = (1 - mask_generate) * input + (
+ mask_generate * dist.sample()[i]
+ )
changed = (pred_result == result).long().min(dim=1).values == 0
i = i & changed
nb_it += 1