model.train()
nb_train_samples, acc_train_loss = 0, 0.0
- data_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
- ]
-
full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples, data_structures=data_structures
+ args.nb_train_samples
)
src = zip(
- full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ full_input.split(args.batch_size),
+ full_mask_generate.split(args.batch_size),
+ full_mask_loss.split(args.batch_size),
)
- for input, mask_loss in tqdm.tqdm(
+ for input, mask_generate, mask_loss in tqdm.tqdm(
src,
dynamic_ncols=True,
desc="training",
total=full_input.size(0) // args.batch_size,
):
input = input.to(local_device)
+ mask_generate = mask_generate.to(local_device)
mask_loss = mask_loss.to(local_device)
if nb_train_samples % args.batch_size == 0:
nb_test_samples, acc_test_loss = 0, 0.0
full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
- args.nb_test_samples, data_structures=data_structures
+ args.nb_test_samples
)
src = zip(
- full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ full_input.split(args.batch_size),
+ full_mask_generate.split(args.batch_size),
+ full_mask_loss.split(args.batch_size),
)
- for input, mask_loss in tqdm.tqdm(
+ for input, mask_generate, mask_loss in tqdm.tqdm(
src,
dynamic_ncols=True,
desc="testing",
total=full_input.size(0) // args.batch_size,
):
input = input.to(local_device)
+ mask_generate = mask_generate.to(local_device)
mask_loss = mask_loss.to(local_device)
+
targets = input
input = (mask_generate == 0).long() * input
output = model(mygpt.BracketedSequence(input)).x
log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
- input, mask_generate, mask_loss = quiz_machine.data_input(
- 128, data_structures=data_structures
- )
+ input, mask_generate, mask_loss = quiz_machine.data_input(128)
input = input.to(local_device)
+ mask_generate = mask_generate.to(local_device)
mask_loss = mask_loss.to(local_device)
targets = input
input = (mask_generate == 0).long() * input
result[:, 1 * L] = input[:, 1 * L]
result[:, 2 * L] = input[:, 2 * L]
result[:, 3 * L] = input[:, 3 * L]
+ correct = (result == targets).min(dim=1).values.long()
+ predicted_parts = input.new(input.size(0), 4)
+
+ nb = 0
+
+ # We consider all the configurations that we train for
+ for struct, quad_generate, _, _ in quiz_machine.test_structures:
+ i = quiz_machine.problem.indices_select(quizzes=input, struct=struct)
+ nb += i.long().sum()
+
+ predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
+ None, :
+ ]
+ solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
+ correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
+
+ assert nb == input.size(0)
+
+ nb_correct = (correct == 1).long().sum()
+ nb_total = (correct != 0).long().sum()
+
+ # Report test accuracy through the script-level log_string helper (the
+ # same one used for the test loss and the "wrote <file>" message); there
+ # is no `self` in this scope, so `self.logger` would raise NameError.
+ log_string(
+ f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+ )
+
+ correct_parts = predicted_parts * correct[:, None]
+
filename = f"prediction_ae_{n_epoch:04d}.png"
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
filename,
quizzes=result,
+ predicted_parts=predicted_parts,
+ correct_parts=correct_parts,
)
log_string(f"wrote {filename}")