X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=0345bd01279063c2cc4edd0356cb903ff29faf15;hb=6917d3d52a4b473d31121a471ab98fa114bdb1a6;hp=622cd567ae5d1d845dfe1b0b186fb1851ae0c80e;hpb=17267a244c31be85db250706fead811f20158810;p=culture.git diff --git a/tasks.py b/tasks.py index 622cd56..0345bd0 100755 --- a/tasks.py +++ b/tasks.py @@ -2244,7 +2244,9 @@ class World(Task): input[:nb_kept] = input[-nb_kept:].clone() input[nb_kept:] = new_quizzes - def create_new_quizzes(self, n_epoch, result_dir, logger, nb, model, nb_runs): + def create_new_quizzes( + self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs + ): new_quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) @@ -2274,15 +2276,18 @@ class World(Task): .expand_as(result) ) - masked_inplace_autoregression( - model, - self.batch_size, - result, - ar_mask, - deterministic_synthesis=False, - progress_bar_desc=None, - device=self.device, - ) + dispatch = torch.randint(len(other_models), (result.size(0),)) + + for n, m in enumerate(other_models): + masked_inplace_autoregression( + m, + self.batch_size, + result[dispatch == n], + ar_mask[dispatch == n], + deterministic_synthesis=False, + progress_bar_desc=None, + device=self.device, + ) nb_correct = ( (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)