X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=tasks.py;h=ad952377ed6a6d1610bc2923985d42b295bc04d5;hb=55aedeed5cb8f0b61a625e64dcaeb0c1fd21d9f6;hp=50d541b11ccab4fb1dca97815412e1e88d69fc93;hpb=9047bd8185ed99c1302d8812551af3d5bd4602cb;p=culture.git diff --git a/tasks.py b/tasks.py index 50d541b..ad95237 100755 --- a/tasks.py +++ b/tasks.py @@ -14,9 +14,6 @@ from torch.nn import functional as F from mygpt import BracketedSequence -# from graph import save_attention_image -save_attention_image = None - ###################################################################### @@ -84,7 +81,7 @@ class World(Task): def save_image(self, input, result_dir, filename, logger): img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") def make_ar_mask(self, input): @@ -104,8 +101,8 @@ class World(Task): self.batch_size = batch_size self.device = device - self.height = 6 - self.width = 8 + self.height = 7 + self.width = 9 self.train_input = world.generate( nb_train_samples, height=self.height, width=self.width @@ -115,6 +112,13 @@ class World(Task): nb_test_samples, height=self.height, width=self.width ).to(device) + # print() + # for a in world.seq2str(self.train_input): + # print(a) + # for a in world.seq2str(self.test_input): + # print(a) + # exit(0) + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 self.train_quizzes = [] @@ -122,7 +126,7 @@ class World(Task): if result_dir is not None: self.save_image( - self.train_input[:96], result_dir, f"world_train.png", logger + self.train_input[:72], result_dir, f"world_train.png", logger ) def batches(self, split="train", desc=None): @@ -218,9 +222,9 @@ class World(Task): ) self.save_image( - result[:96], + result[:72], result_dir, - f"world_result_{n_epoch:04d}_{model.id:02d}.png", + f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, ) @@ -241,27 +245,46 @@ class World(Task): model, other_models, ): - new_quizzes = torch.empty( + ############################################################### + # Generate quizzes with model + + quizzes = torch.empty( nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 ) - ar_mask = torch.full(new_quizzes.size(), 1, device=self.device) + ar_mask = torch.full(quizzes.size(), 1, device=self.device) masked_inplace_autoregression( model, self.batch_size, - new_quizzes, + quizzes, ar_mask, deterministic_synthesis=False, - progress_bar_desc="new quizzes", + progress_bar_desc="creating quizzes", device=self.device, ) - ar_mask = self.make_ar_mask(new_quizzes) + ############################################################### + # Create the reverse quizzes + + l = self.height * self.width + direction = quizzes[:, l : l + 1] + direction = world.token_forward * ( + direction == world.token_backward + ) + world.token_backward * (direction == world.token_forward) + reverse_quizzes = torch.cat( + [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1 + ) + + ar_mask = self.make_ar_mask(quizzes) - nb_correct = 0 + ############################################################### + # Check how many of the other models can solve them in both + # directions + + nb_correct = [] for m in other_models: - result = new_quizzes.clone() + result = quizzes.clone() masked_inplace_autoregression( m, @@ -273,34 +296,31 @@ class World(Task): device=self.device, ) - l = self.height * self.width - direction = new_quizzes[:, l : l + 1] - direction = world.token_forward * ( - direction == world.token_backward - ) + world.token_backward * (direction == world.token_forward) - inverted_quizzes = torch.cat( - [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1 - ) + correct = (quizzes == result).long().min(dim=-1).values - inverted_result = inverted_quizzes.clone() + reverse_result = reverse_quizzes.clone() masked_inplace_autoregression( m, self.batch_size, - inverted_result, + reverse_result, ar_mask, deterministic_synthesis=True, - progress_bar_desc="solving reverse quizzes", + progress_bar_desc="solving reversed quizzes", device=self.device, ) - nb_correct += ( - ( - (new_quizzes == result).long() - * (inverted_quizzes, inverted_result).long() - ) - .min(dim=-1) - .values + reverse_correct = ( + (reverse_quizzes == reverse_result).long().min(dim=-1).values ) - return new_quizzes, nb_correct + nb_correct.append((correct * reverse_correct)[None, :]) + + nb_correct = torch.cat(nb_correct, dim=0) + + filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") + with open(filename, "w") as f: + for k in nb_correct: + f.write(f"{k}\n") + + return quizzes, nb_correct.sum(dim=0)