X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=1e6ed4d26659256637b2bb3d874ce2997a94b814;hb=45018258b05a979037477b978ef0c55fd444af7e;hp=cb25ea0ec335fc1823d2f4d7d044ed7908184a0f;hpb=c8979c695ad584c54d605b8f183e5d2e99f2d1cc;p=culture.git diff --git a/sky.py b/sky.py index cb25ea0..1e6ed4d 100755 --- a/sky.py +++ b/sky.py @@ -14,19 +14,10 @@ from torch.nn import functional as F ###################################################################### +import problem -class Problem: - def generate_seq(self, nb_train_samples): - pass - def save_quizzes(self, input, result_dir, filename_prefix, logger): - pass - - def direction_tokens(self): - pass - - -class Sky: +class Sky(problem.Problem): colors = torch.tensor( [ [255, 255, 255], @@ -352,14 +343,13 @@ class Sky: result.append("".join([self.token2char[v] for v in s])) return result - def save_image(self, input, result_dir, filename, logger): + def save_image(self, input, result_dir, filename): img = self.seq2img(input.to("cpu")) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - logger(f"wrote {image_name}") - def save_quizzes(self, input, result_dir, filename_prefix, logger): - self.save_image(input, result_dir, filename_prefix + ".png", logger) + def save_quizzes(self, input, result_dir, filename_prefix): + self.save_image(input, result_dir, filename_prefix + ".png") ######################################################################