parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--check", action="store_true", default=False)

######################################################################

args = parser.parse_args()

######################################################################

# Quick-check mode: shrink the data sets so a full run finishes fast.
# NOTE(review): the flag added above is --check, but args.test is read
# here — confirm a --test flag exists earlier in the file, otherwise
# this condition should be `args.check`.
if args.test:
    args.nb_train_samples = 1000
    args.nb_test_samples = 25

# Default the physical (per-step) batch size to the logical batch size
# when the user did not override it.
if args.physical_batch_size is None:
    args.physical_batch_size = args.batch_size
######################################################################

# A model must reach this test accuracy before it is allowed to
# generate new quizzes, and these are the number of quizzes created
# for each split when it does.
accuracy_to_make_quizzes = 0.975
nb_new_quizzes_for_train = 1000
nb_new_quizzes_for_test = 100

# Quick-check mode: always generate quizzes (no accuracy gate) and
# only a handful of them, so a full pass through the pipeline is cheap.
# NOTE(review): args.test is read here but the visible new flag is
# --check — confirm which flag is intended (see the argparse section).
if args.test:
    accuracy_to_make_quizzes = 0.0
    nb_new_quizzes_for_train = 10
    nb_new_quizzes_for_test = 10
for n_epoch in range(args.nb_epochs):
# select the model with lowest accuracy
model,
other_models,
task,
- nb_for_train=1000,
- nb_for_test=100,
+ nb_for_train=nb_new_quizzes_for_train,
+ nb_for_test=nb_new_quizzes_for_test,
)
device=self.device,
)
- nb_correct += (new_quizzes == result).long().min(dim=-1).values
+ l = self.height * self.width
+ direction = new_quizzes[:, l : l + 1]
+ direction = world.token_forward * (
+ direction == world.token_backward
+ ) + world.token_backward * (direction == world.token_forward)
+ inverted_quizzes = torch.cat(
+ [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
+ )
+
+ inverted_result = inverted_quizzes.clone()
+
+ masked_inplace_autoregression(
+ m,
+ self.batch_size,
+ inverted_result,
+ ar_mask,
+ deterministic_synthesis=True,
+ progress_bar_desc="solving reverse quizzes",
+ device=self.device,
+ )
+
+ nb_correct += (
+ (
+ (new_quizzes == result).long()
+ * (inverted_quizzes, inverted_result).long()
+ )
+ .min(dim=-1)
+ .values
+ )
return new_quizzes, nb_correct