targets = input
input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
- pred_result = None
result = (1 - mask_generate) * input + mask_generate * torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
- i = torch.full((result.size(0),), True, device=result.device)
+ not_converged = torch.full((result.size(0),), True, device=result.device)
nb_it = 0
logits = model(mygpt.BracketedSequence(result)).x
dist = torch.distributions.categorical.Categorical(logits=logits)
pred_result = result.clone()
- result[i] = (1 - mask_generate[i]) * input + (
- mask_generate * dist.sample()[i]
- )
- changed = (pred_result == result).long().min(dim=1).values == 0
- i = i & changed
+ result[not_converged] = (
+ (1 - mask_generate) * input + mask_generate * dist.sample()
+ )[not_converged]
+ not_converged = (pred_result == result).long().min(dim=1).values == 0
nb_it += 1
print("DEBUG", nb_it, i.long().sum().item())
if not i.any() or nb_it > 100: