("red", [255, 0, 0]),
("green", [0, 192, 0]),
("blue", [0, 0, 255]),
- ("orange", [255, 192, 0]),
+ ("yellow", [255, 224, 0]),
("cyan", [0, 255, 255]),
- ("violet", [255, 0, 255]),
+ ("violet", [224, 128, 255]),
("lightgreen", [192, 255, 192]),
("brown", [165, 42, 42]),
("lightblue", [192, 192, 255]),
def __init__(self, device=torch.device("cpu")):
self.colors = torch.tensor([c for _, c in self.named_colors])
- self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)])
self.height = 10
self.width = 10
self.device = device
return x
- def frame2img_(self, x, scale=15):
- x = x.reshape(x.size(0), self.height, -1)
- x = self.colors[x].permute(0, 3, 1, 2)
- s = x.shape
- x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
- x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
-
- x[:, :, :, torch.arange(0, x.size(3), scale)] = 0
- x[:, :, torch.arange(0, x.size(2), scale), :] = 0
- x = x[:, :, 1:, 1:]
-
- return x
-
def save_image(
self,
result_dir,