)
-class MyAttentionVAE(nn.Module):
+class MyAttentionAE(nn.Module):
def __init__(
self,
vocabulary_size,
def test_ae(local_device=main_device):
- model = MyAttentionVAE(
+ model = MyAttentionAE(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
# -------------------------------------------
# Test generation
- input, mask_generate, mask_loss = next(
- ae_batches(quiz_machine, 128, data_structures, local_device)
- )
+ for ns, s in enumerate(data_structures):  # one evaluation pass per structure; ns tags the log line and the image filename
+ quad_order, quad_generate, _, _ = s  # NOTE(review): quad_order is not read by any added line in this hunk
- targets = input
+ input, mask_generate, mask_loss = next(
+ ae_batches(quiz_machine, 128, [s], local_device)  # batch is drawn from the single structure s only
+ )
- input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
+ targets = input  # keep the unmasked batch; the next line rebinds input to a new tensor, so targets is unaffected
- result = (1 - mask_generate) * input + mask_generate * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
+ input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
- not_converged = torch.full((result.size(0),), True, device=result.device)
+ result = (1 - mask_generate) * input + mask_generate * torch.randint(  # start point: input where the mask is 0, random values in [0, nb_colors) where it is 1 (assumes a 0/1 mask, TODO confirm)
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
- nb_it = 0
+ not_converged = torch.full(  # one boolean per row; every row starts as not converged
+ (result.size(0),), True, device=result.device
+ )
- while True:
- logits = model(mygpt.BracketedSequence(result)).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- pred_result = result.clone()
- update = (1 - mask_generate) * input + mask_generate * dist.sample()
- result[not_converged] = update[not_converged]
- not_converged = (pred_result != result).max(dim=1).values
- nb_it += 1
- print("DEBUG", nb_it, not_converged.long().sum().item())
- if not not_converged.any() or nb_it > 100:
- break
+ nb_it = 0
- correct = (result == targets).min(dim=1).values.long()
- predicted_parts = input.new(input.size(0), 4)
+ while True:  # resample the masked positions until no row changes, or the iteration cap below is hit
+ logits = model(mygpt.BracketedSequence(result)).x
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ pred_result = result.clone()  # snapshot taken before the update, used for the change test below
+ update = (1 - mask_generate) * input + mask_generate * dist.sample()
+ result[not_converged] = update[not_converged]  # rows already flagged as converged are left untouched
+ not_converged = (pred_result != result).max(dim=1).values  # a row stays flagged if any position changed in this pass
+ nb_it += 1
+ print("DEBUG", nb_it, not_converged.long().sum().item())  # NOTE(review): debug print, bypasses log_string
+ if not not_converged.any() or nb_it > 100:  # nb_it is tested after the increment, so at most 101 passes run
+ break
- nb = 0
+ correct = (result == targets).min(dim=1).values.long()  # 1 when every position of the row equals targets, else 0
+ predicted_parts = input.new(input.size(0), 4)  # NOTE(review): dead store, predicted_parts is reassigned below before it is read
- # We consider all the configurations that we train for
- for quad_order, quad_generate, _, _ in quiz_machine.test_structures:
- i = quiz_machine.problem.indices_select(
- quizzes=input, quad_order=quad_order
- )
- nb += i.long().sum()
+ nb = 0  # NOTE(review): nb is never read in this hunk now that the assert using it is removed
- predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
+ predicted_parts = torch.tensor(quad_generate, device=result.device)[  # leading axis added by [None, :] so the same row broadcasts over the batch
None, :
]
- solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
- correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
-
- assert nb == input.size(0)
+ solution_is_deterministic = predicted_parts.sum(dim=-1) == 1  # true when the entries of quad_generate sum to exactly 1
+ correct = (2 * correct - 1) * (solution_is_deterministic).long()  # +1 correct, -1 wrong, 0 when not deterministic (dropped from nb_total)
- nb_correct = (correct == 1).long().sum()
- nb_total = (correct != 0).long().sum()
+ nb_correct = (correct == 1).long().sum()
+ nb_total = (correct != 0).long().sum()
- log_string(
- f"test_accuracy {n_epoch} model AE {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
- )
+ log_string(  # NOTE(review): for a non-deterministic setup every score is 0, so nb_total is 0 and the percentage is 0/0
+ f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+ )
- correct_parts = predicted_parts * correct[:, None]
+ correct_parts = predicted_parts * correct[:, None]  # broadcasts the single predicted_parts row against one score per batch row
+ predicted_parts = predicted_parts.expand_as(correct_parts)  # give predicted_parts the same per-row shape as correct_parts before saving
- filename = f"prediction_ae_{n_epoch:04d}.png"
+ filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result,
- predicted_parts=predicted_parts,
- correct_parts=correct_parts,
- )
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=result,
+ predicted_parts=predicted_parts,
+ correct_parts=correct_parts,
+ )
- log_string(f"wrote {filename}")
+ log_string(f"wrote {filename}")
if args.test == "ae":