X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=cb25ea0ec335fc1823d2f4d7d044ed7908184a0f;hb=c8979c695ad584c54d605b8f183e5d2e99f2d1cc;hp=a90e37d542aebd5260b3dbdb7118c0ba4e2cc33f;hpb=b76e3f632315c63dbd8f11a53b187f23057e4e1f;p=culture.git diff --git a/sky.py b/sky.py index a90e37d..cb25ea0 100755 --- a/sky.py +++ b/sky.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm +import math, sys, tqdm, os import torch, torchvision @@ -15,6 +15,17 @@ from torch.nn import functional as F ###################################################################### +class Problem: + def generate_seq(self, nb_train_samples): + pass + + def save_quizzes(self, input, result_dir, filename_prefix, logger): + pass + + def direction_tokens(self): + pass + + class Sky: colors = torch.tensor( [ @@ -48,6 +59,9 @@ class Sky: self.nb_birds = nb_birds self.nb_iterations = nb_iterations + def direction_tokens(self): + return self.token_forward, self.token_backward + def generate_seq(self, nb, return_iterations=False): pairs = [] kept_iterations = [] @@ -338,6 +352,15 @@ class Sky: result.append("".join([self.token2char[v] for v in s])) return result + def save_image(self, input, result_dir, filename, logger): + img = self.seq2img(input.to("cpu")) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) + logger(f"wrote {image_name}") + + def save_quizzes(self, input, result_dir, filename_prefix, logger): + self.save_image(input, result_dir, filename_prefix + ".png", logger) + ######################################################################