- c = c * torch.tensor([0, 0, 0], device=c.device) + (
- 1 - c
- ) * torch.tensor([255, 255, 255], device=c.device)
+ c = (
+ (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+ + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+ + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+ )