X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=sky.py;h=1e6ed4d26659256637b2bb3d874ce2997a94b814;hb=4f0057b363762698f90eea05de154e62b6883bd0;hp=a90e37d542aebd5260b3dbdb7118c0ba4e2cc33f;hpb=b76e3f632315c63dbd8f11a53b187f23057e4e1f;p=culture.git diff --git a/sky.py b/sky.py index a90e37d..1e6ed4d 100755 --- a/sky.py +++ b/sky.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm +import math, sys, tqdm, os import torch, torchvision @@ -14,8 +14,10 @@ from torch.nn import functional as F ###################################################################### +import problem -class Sky: + +class Sky(problem.Problem): colors = torch.tensor( [ [255, 255, 255], @@ -48,6 +50,9 @@ class Sky: self.nb_birds = nb_birds self.nb_iterations = nb_iterations + def direction_tokens(self): + return self.token_forward, self.token_backward + def generate_seq(self, nb, return_iterations=False): pairs = [] kept_iterations = [] @@ -338,6 +343,14 @@ class Sky: result.append("".join([self.token2char[v] for v in s])) return result + def save_image(self, input, result_dir, filename): + img = self.seq2img(input.to("cpu")) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) + + def save_quizzes(self, input, result_dir, filename_prefix): + self.save_image(input, result_dir, filename_prefix + ".png") + ######################################################################