X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=ad952377ed6a6d1610bc2923985d42b295bc04d5;hb=908351dd77e8a703fb55b32a209c2fca4f551669;hp=cb5900b74877e1582e9a8f1cef0dd76c8ea61814;hpb=1506fb905b0f83034107e8e8dc336d10bdb1a7a7;p=culture.git diff --git a/tasks.py b/tasks.py index cb5900b..ad95237 100755 --- a/tasks.py +++ b/tasks.py @@ -81,7 +81,7 @@ class World(Task): def save_image(self, input, result_dir, filename, logger): img = world.sample2img(input.to("cpu"), self.height, self.width) image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) logger(f"wrote {image_name}") def make_ar_mask(self, input): @@ -101,8 +101,8 @@ class World(Task): self.batch_size = batch_size self.device = device - self.height = 6 - self.width = 8 + self.height = 7 + self.width = 9 self.train_input = world.generate( nb_train_samples, height=self.height, width=self.width @@ -112,6 +112,13 @@ class World(Task): nb_test_samples, height=self.height, width=self.width ).to(device) + # print() + # for a in world.seq2str(self.train_input): + # print(a) + # for a in world.seq2str(self.test_input): + # print(a) + # exit(0) + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 self.train_quizzes = [] @@ -119,7 +126,7 @@ class World(Task): if result_dir is not None: self.save_image( - self.train_input[:96], result_dir, f"world_train.png", logger + self.train_input[:72], result_dir, f"world_train.png", logger ) def batches(self, split="train", desc=None): @@ -215,7 +222,7 @@ class World(Task): ) self.save_image( - result[:96], + result[:72], result_dir, f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, @@ -274,7 +281,7 @@ class World(Task): # Check how many of the other models can solve them in both # directions - nb_correct = 0 + nb_correct = [] for m in other_models: result = quizzes.clone() @@ -307,6 +314,13 @@ class World(Task): (reverse_quizzes == reverse_result).long().min(dim=-1).values ) - nb_correct += correct * reverse_correct + nb_correct.append((correct * reverse_correct)[None, :]) + + nb_correct = torch.cat(nb_correct, dim=0) + + filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") + with open(filename, "w") as f: + for k in nb_correct: + f.write(f"{k}\n") - return quizzes, nb_correct + return quizzes, nb_correct.sum(dim=0)