Update.
[culture.git] / sky.py
diff --git a/sky.py b/sky.py
index 6ef8a3a..ed440d3 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -217,9 +217,11 @@ class Sky(problem.Problem):
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([0, 0, 0], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
 
             y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
@@ -322,8 +324,8 @@ if __name__ == "__main__":
 
     prompts, answers = sky.generate_prompts_and_answers(4)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.rand(answers.size(0)) < 0.5
+    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
+    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
     sky.save_quizzes(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers