X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=ed440d37a37af02b6b2a19cb6ce20945d3d59afe;hb=9ec709a2a08eb82dfc17ef1e24aa9a84751d63e0;hp=6ef8a3af2184777c223dfd1803647a49bd3dd54d;hpb=00f321c2e2a9b7be1edcb1453bf0d45f52e50919;p=culture.git diff --git a/sky.py b/sky.py index 6ef8a3a..ed440d3 100755 --- a/sky.py +++ b/sky.py @@ -217,9 +217,11 @@ class Sky(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([0, 0, 0], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -322,8 +324,8 @@ if __name__ == "__main__": prompts, answers = sky.generate_prompts_and_answers(4) - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.rand(answers.size(0)) < 0.5 + predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 + predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 sky.save_quizzes( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers