targets = input
- mask_diffusion_noise = (mask_generate == 1) & (
- torch.rand(mask_generate.size(), device=mask_generate.device)
- <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
- )
+ # mask_diffusion_noise = (mask_generate == 1) & (
+ # torch.rand(mask_generate.size(), device=mask_generate.device)
+ # <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+ # )
- mask_diffusion_noise = mask_diffusion_noise.long()
+ # mask_diffusion_noise = mask_diffusion_noise.long()
- input = (
- 1 - mask_diffusion_noise
- ) * input + mask_diffusion_noise * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
+ # input = (
+ # 1 - mask_diffusion_noise
+ # ) * input + mask_diffusion_noise * torch.randint(
+ # quiz_machine.problem.nb_colors, input.size(), device=input.device
+ # )
+
+ # ------------------------------
+ input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
+ model.eval()
+ for it in range(torch.randint(5, (1,)).item()):
+ logits = model(mygpt.BracketedSequence(input)).x
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ input = (1 - mask_generate) * input + mask_generate * dist.sample()
+ model.train()
+ # -----------------------------
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
):
targets = input
- mask_diffusion_noise = (mask_generate == 1) & (
- torch.rand(mask_generate.size(), device=mask_generate.device)
- <= torch.rand(
- (mask_generate.size(0), 1), device=mask_generate.device
- )
- )
+ # mask_diffusion_noise = (mask_generate == 1) & (
+ # torch.rand(mask_generate.size(), device=mask_generate.device)
+ # <= torch.rand(
+ # (mask_generate.size(0), 1), device=mask_generate.device
+ # )
+ # )
- mask_diffusion_noise = mask_diffusion_noise.long()
+ # mask_diffusion_noise = mask_diffusion_noise.long()
- input = (
- 1 - mask_diffusion_noise
- ) * input + mask_diffusion_noise * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
+ # input = (
+ # 1 - mask_diffusion_noise
+ # ) * input + mask_diffusion_noise * torch.randint(
+ # quiz_machine.problem.nb_colors, input.size(), device=input.device
+ # )
+
+ # ------------------------------
+ input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
+
+ for it in range(torch.randint(5, (1,)).item()):
+ logits = model(mygpt.BracketedSequence(input)).x
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ input = (1 - mask_generate) * input + mask_generate * dist.sample()
+ # -----------------------------
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
- result = (1 - mask_generate) * input + mask_generate * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
+ result = (1 - mask_generate) * input
+
+ # + mask_generate * torch.randint(
+ # quiz_machine.problem.nb_colors, input.size(), device=input.device
+ # )
not_converged = torch.full(
(result.size(0),), True, device=result.device
)
- nb_it = 0
-
- while True:
- logits = model(mygpt.BracketedSequence(result)).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
+ for it in range(100):
pred_result = result.clone()
- update = (1 - mask_generate) * input + mask_generate * dist.sample()
- result[not_converged] = update[not_converged]
+ logits = model(mygpt.BracketedSequence(result[not_converged])).x
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ update = (1 - mask_generate[not_converged]) * input[
+ not_converged
+ ] + mask_generate[not_converged] * dist.sample()
+ result[not_converged] = update
not_converged = (pred_result != result).max(dim=1).values
- nb_it += 1
- print("DEBUG", nb_it, not_converged.long().sum().item())
- if not not_converged.any() or nb_it > 100:
+ if not not_converged.any():
+ log_string(f"diffusion_converged {it=}")
break
correct = (result == targets).min(dim=1).values.long()