Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index a90e37d..cb25ea0 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
-import math, sys, tqdm
+import math, sys, tqdm, os
 
 import torch, torchvision
 
@@ -15,6 +15,17 @@ from torch.nn import functional as F
 ######################################################################
 
 
+class Problem:
+    def generate_seq(self, nb_train_samples):
+        pass
+
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        pass
+
+    def direction_tokens(self):
+        pass
+
+
 class Sky:
     colors = torch.tensor(
         [
@@ -48,6 +59,9 @@ class Sky:
         self.nb_birds = nb_birds
         self.nb_iterations = nb_iterations
 
+    def direction_tokens(self):
+        return self.token_forward, self.token_backward
+
     def generate_seq(self, nb, return_iterations=False):
         pairs = []
         kept_iterations = []
@@ -338,6 +352,15 @@ class Sky:
             result.append("".join([self.token2char[v] for v in s]))
         return result
 
+    def save_image(self, input, result_dir, filename, logger):
+        img = self.seq2img(input.to("cpu"))
+        image_name = os.path.join(result_dir, filename)
+        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+        logger(f"wrote {image_name}")
+
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
 
 ######################################################################