X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=sky.py;h=ac6cbdc71f0e476fcf1b8c27b2d0407ca9ea533a;hb=8e37a868ac7dfc1cb5e924790929c6eebabbeb94;hp=cb25ea0ec335fc1823d2f4d7d044ed7908184a0f;hpb=c8979c695ad584c54d605b8f183e5d2e99f2d1cc;p=culture.git diff --git a/sky.py b/sky.py index cb25ea0..ac6cbdc 100755 --- a/sky.py +++ b/sky.py @@ -14,19 +14,10 @@ from torch.nn import functional as F ###################################################################### +import problem -class Problem: - def generate_seq(self, nb_train_samples): - pass - def save_quizzes(self, input, result_dir, filename_prefix, logger): - pass - - def direction_tokens(self): - pass - - -class Sky: +class Sky(problem.Problem): colors = torch.tensor( [ [255, 255, 255], @@ -267,31 +258,34 @@ class Sky: return torch.cat(result, dim=0) - def frame2img(self, x, upscale=15): + def frame2img(self, x, scale=15): x = x.reshape(-1, self.height, self.width) m = torch.logical_and( x >= 0, x < self.first_bird_token + self.nb_bird_tokens ).long() x = self.colors[x * m].permute(0, 3, 1, 2) s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale) - x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale) + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0 - x[:, :, torch.arange(0, x.size(2), upscale), :] = 0 + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 x = x[:, :, 1:, 1:] for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): if m[n, i, j] == 0: - for k in range(2, upscale - 2): - x[n, :, i * upscale + k, j * upscale + k] = 0 - x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0 + for k in range(2, scale - 2): + for l in [0, 1]: + x[n, :, i * scale + k, j * scale + k - l] = 0 + x[ + n, :, i * scale + scale - 1 - k, j * scale + k - l + ] = 0 return x - def seq2img(self, seq, upscale=15): + def seq2img(self, seq, scale=15): f_first = seq[:, : self.height * self.width].reshape( -1, self.height, self.width ) @@ -301,47 +295,53 @@ class Sky: direction = seq[:, self.height * self.width] direction_symbol = torch.full( - (direction.size(0), self.height * upscale - 1, upscale), 0 + (direction.size(0), self.height * scale - 1, scale), 0 ) direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2) - separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0) + separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0) for n in range(direction_symbol.size(0)): if direction[n] == self.token_forward: - for k in range(upscale): - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - 3 + upscale // 2 - abs(k - upscale // 2), - ] = 0 + for k in range(scale): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + scale // 2 - abs(k - scale // 2), + ] = 0 elif direction[n] == self.token_backward: - for k in range(upscale): - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - 3 + abs(k - upscale // 2), - ] = 0 + for k in range(scale): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + abs(k - scale // 2), + ] = 0 else: - for k in range(2, upscale - 2): - direction_symbol[ - n, :, (self.height * upscale) // 2 - upscale // 2 + k, k - ] = 0 - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - upscale - 1 - k, - ] = 0 + for k in range(2, scale - 2): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + k, + ] = 0 + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + scale - 1 - k, + ] = 0 return torch.cat( [ - self.frame2img(f_first, upscale), + self.frame2img(f_first, scale), separator, direction_symbol, separator, - self.frame2img(f_second, upscale), + self.frame2img(f_second, scale), ], dim=3, ) @@ -352,14 +352,13 @@ class Sky: result.append("".join([self.token2char[v] for v in s])) return result - def save_image(self, input, result_dir, filename, logger): + def save_image(self, input, result_dir, filename): img = self.seq2img(input.to("cpu")) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - logger(f"wrote {image_name}") - def save_quizzes(self, input, result_dir, filename_prefix, logger): - self.save_image(input, result_dir, filename_prefix + ".png", logger) + def save_quizzes(self, input, result_dir, filename_prefix): + self.save_image(input, result_dir, filename_prefix + ".png") ###################################################################### @@ -372,7 +371,7 @@ if __name__ == "__main__": start_time = time.perf_counter() seq, it = sky.generate_seq(nb=64, return_iterations=True) delay = time.perf_counter() - start_time - print(f"{seq.size(0)/delay:02f} samples/s") + print(f"{seq.size(0)/delay:02f} seq/s") print(sky.seq2str(seq[:4]))