y[...] = c
else:
c = c.long()[:, None]
- c = c * torch.tensor([0, 0, 0], device=c.device) + (
- 1 - c
- ) * torch.tensor([255, 255, 255], device=c.device)
+ c = (
+ (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+ + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+ + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+ )
y[...] = c[:, :, None, None]
y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
prompts, answers = sky.generate_prompts_and_answers(4)
- predicted_prompts = torch.rand(prompts.size(0)) < 0.5
- predicted_answers = torch.rand(answers.size(0)) < 0.5
+ predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
+ predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
sky.save_quizzes(
"/tmp", "test", prompts, answers, predicted_prompts, predicted_answers