def create_quizzes(
+ model,
other_models,
task,
nb_for_train=1000,
result_dir=args.result_dir,
logger=log_string,
nb=4 * (nb_for_train + nb_for_test),
- models=other_models,
+ model=model,
+ other_models=other_models,
nb_runs=nb_runs,
)
input[:nb_kept] = input[-nb_kept:].clone()
input[nb_kept:] = new_quizzes
- def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs):
+ def create_new_quizzes(
+ self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs
+ ):
new_quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
.expand_as(result)
)
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis=False,
- progress_bar_desc=None,
- device=self.device,
- )
+ dispatch = torch.randint(len(other_models), (result.size(0),))
+
+ for n, m in enumerate(other_models):
+ masked_inplace_autoregression(
+ m,
+ self.batch_size,
+ result[dispatch == n],
+ ar_mask[dispatch == n],
+ deterministic_synthesis=False,
+ progress_bar_desc=None,
+ device=self.device,
+ )
nb_correct = (
(input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)