X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=tasks.py;h=8680ba14e719852dd1f0cefe4bb3674195edf3bc;hb=31ed8a54992e7701eebd1c3d49bfe8dc20aa65e3;hp=77493a81784c969a61ba6fd27416d0e79221ea6a;hpb=61c98647a2d708c8f2c5f0d25bcf05df92e1233f;p=culture.git diff --git a/tasks.py b/tasks.py index 77493a8..8680ba1 100755 --- a/tasks.py +++ b/tasks.py @@ -220,7 +220,7 @@ class World(Task): self.save_image( result[:96], result_dir, - f"world_result_{n_epoch:04d}_{model.id:02d}.png", + f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, ) @@ -273,6 +273,29 @@ class World(Task): device=self.device, ) - nb_correct += (new_quizzes == result).long().min(dim=-1).values + l = self.height * self.width + direction = new_quizzes[:, l : l + 1] + direction = world.token_forward * ( + direction == world.token_backward + ) + world.token_backward * (direction == world.token_forward) + inverted_quizzes = torch.cat( + [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1 + ) + + inverted_result = inverted_quizzes.clone() + + masked_inplace_autoregression( + m, + self.batch_size, + inverted_result, + ar_mask, + deterministic_synthesis=True, + progress_bar_desc="solving reverse quizzes", + device=self.device, + ) + + nb_correct += (new_quizzes == result).long().min(dim=-1).values * ( + inverted_quizzes == inverted_result + ).long().min(dim=-1).values return new_quizzes, nb_correct