X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=grids.py;h=247c146d16140ed24cc636d1baf74dc21aeefbeb;hb=c9c9df43a6b97e5b3e81c8cf05d2f1b3010dea05;hp=9462f87676687503576c030b1ac111c9a1fde0c3;hpb=b269012c29f1eeae7d51625694269da40326a69f;p=culture.git diff --git a/grids.py b/grids.py index 9462f87..247c146 100755 --- a/grids.py +++ b/grids.py @@ -23,9 +23,9 @@ class Grids(problem.Problem): ("red", [255, 0, 0]), ("green", [0, 192, 0]), ("blue", [0, 0, 255]), - ("orange", [255, 192, 0]), + ("yellow", [255, 224, 0]), ("cyan", [0, 255, 255]), - ("violet", [255, 0, 255]), + ("violet", [224, 128, 255]), ("lightgreen", [192, 255, 192]), ("brown", [165, 42, 42]), ("lightblue", [192, 192, 255]), @@ -34,7 +34,6 @@ class Grids(problem.Problem): def __init__(self, device=torch.device("cpu")): self.colors = torch.tensor([c for _, c in self.named_colors]) - self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)]) self.height = 10 self.width = 10 self.device = device @@ -66,19 +65,6 @@ class Grids(problem.Problem): return x - def frame2img_(self, x, scale=15): - x = x.reshape(x.size(0), self.height, -1) - x = self.colors[x].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] - - return x - def save_image( self, result_dir,