# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, tqdm
+import math, sys, tqdm, os
import torch, torchvision
######################################################################
+import problem
-class Sky:
+
+class Sky(problem.Problem):
colors = torch.tensor(
[
[255, 255, 255],
self.nb_birds = nb_birds
self.nb_iterations = nb_iterations
+ def direction_tokens(self):
+ return self.token_forward, self.token_backward
+
def generate_seq(self, nb, return_iterations=False):
pairs = []
kept_iterations = []
result.append("".join([self.token2char[v] for v in s]))
return result
+ def save_image(self, input, result_dir, filename):
+ img = self.seq2img(input.to("cpu"))
+ image_name = os.path.join(result_dir, filename)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+
+ def save_quizzes(self, input, result_dir, filename_prefix):
+ self.save_image(input, result_dir, filename_prefix + ".png")
+
######################################################################