X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=f6d34a8d30f5439362521d374775f375c7c2b3ca;hb=f87a57354a1e575181e760fdaedbb2c2d5cf9fa0;hp=77493a81784c969a61ba6fd27416d0e79221ea6a;hpb=61c98647a2d708c8f2c5f0d25bcf05df92e1233f;p=culture.git diff --git a/tasks.py b/tasks.py index 77493a8..f6d34a8 100755 --- a/tasks.py +++ b/tasks.py @@ -14,9 +14,6 @@ from torch.nn import functional as F from mygpt import BracketedSequence -# from graph import save_attention_image -save_attention_image = None - ###################################################################### @@ -220,7 +217,7 @@ class World(Task): self.save_image( result[:96], result_dir, - f"world_result_{n_epoch:04d}_{model.id:02d}.png", + f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, ) @@ -252,7 +249,7 @@ class World(Task): new_quizzes, ar_mask, deterministic_synthesis=False, - progress_bar_desc="new quizzes", + progress_bar_desc="creating quizzes", device=self.device, ) @@ -273,6 +270,29 @@ class World(Task): device=self.device, ) - nb_correct += (new_quizzes == result).long().min(dim=-1).values + l = self.height * self.width + direction = new_quizzes[:, l : l + 1] + direction = world.token_forward * ( + direction == world.token_backward + ) + world.token_backward * (direction == world.token_forward) + inverted_quizzes = torch.cat( + [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1 + ) + + inverted_result = inverted_quizzes.clone() + + masked_inplace_autoregression( + m, + self.batch_size, + inverted_result, + ar_mask, + deterministic_synthesis=True, + progress_bar_desc="solving reversed quizzes", + device=self.device, + ) + + nb_correct += (new_quizzes == result).long().min(dim=-1).values * ( + inverted_quizzes == inverted_result + ).long().min(dim=-1).values return new_quizzes, nb_correct