model.optimizer.zero_grad()
targets = input
- input = (mask_generate == 0).long() * input
+
+ input = (mask_generate == 0).long() * input + (
+ 1 - (mask_generate == 0).long()
+ ) * torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
mask_loss = mask_loss.to(local_device)
targets = input
- input = (mask_generate == 0).long() * input
+
+ input = (mask_generate == 0).long() * input + (
+ 1 - (mask_generate == 0).long()
+ ) * torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
+
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
acc_test_loss += loss.item() * input.size(0)
mask_generate = mask_generate.to(local_device)
mask_loss = mask_loss.to(local_device)
targets = input
- input = (mask_generate == 0).long() * input
- logits = model(mygpt.BracketedSequence(input)).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- result = dist.sample()
+
+ pred_result = None
+ frozen = None
+
+ result = (mask_generate == 0).long() * input + (
+ 1 - (mask_generate == 0).long()
+ ) * torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
+
+ i = torch.full((result.size(0),), True, device=result.device)
+
+ nb_it = 0
+
L = input.size(1) // 4
- result[:, 0 * L] = input[:, 0 * L]
- result[:, 1 * L] = input[:, 1 * L]
- result[:, 2 * L] = input[:, 2 * L]
- result[:, 3 * L] = input[:, 3 * L]
+
+ while True:
+ logits = model(mygpt.BracketedSequence(result)).x
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ pred_result = result.clone()
+ result[i] = dist.sample()[i]
+ result[:, 0 * L] = input[:, 0 * L]
+ result[:, 1 * L] = input[:, 1 * L]
+ result[:, 2 * L] = input[:, 2 * L]
+ result[:, 3 * L] = input[:, 3 * L]
+ changed = (pred_result == result).long().min(dim=1).values == 0
+ i = i & changed
+ nb_it += 1
+ print("DEBUG", nb_it, i.long().sum().item())
+ if not i.any() or nb_it > 100:
+ break
+
correct = (result == targets).min(dim=1).values.long()
predicted_parts = input.new(input.size(0), 4)
nb_correct = (correct == 1).long().sum()
nb_total = (correct != 0).long().sum()
- self.logger(
- f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+ log_string(
+ f"test_accuracy {n_epoch} model AE {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
)
correct_parts = predicted_parts * correct[:, None]