Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index d2a4568..2183cf1 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -165,7 +165,7 @@ class Sky(problem.Problem):
     ######################################################################
 
     def frame2img(self, x, scale=15):
-        x = x.reshape(-1, self.height, self.width)
+        x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
@@ -274,7 +274,7 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
 
     prompts, answers = sky.generate_prompts_and_answers(4)