X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=sky.py;h=ac6cbdc71f0e476fcf1b8c27b2d0407ca9ea533a;hb=f78cdbad69a877df92df41094a9f3f1036a1582a;hp=a90e37d542aebd5260b3dbdb7118c0ba4e2cc33f;hpb=b76e3f632315c63dbd8f11a53b187f23057e4e1f;p=culture.git diff --git a/sky.py b/sky.py index a90e37d..ac6cbdc 100755 --- a/sky.py +++ b/sky.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, sys, tqdm +import math, sys, tqdm, os import torch, torchvision @@ -14,8 +14,10 @@ from torch.nn import functional as F ###################################################################### +import problem -class Sky: + +class Sky(problem.Problem): colors = torch.tensor( [ [255, 255, 255], @@ -48,6 +50,9 @@ class Sky: self.nb_birds = nb_birds self.nb_iterations = nb_iterations + def direction_tokens(self): + return self.token_forward, self.token_backward + def generate_seq(self, nb, return_iterations=False): pairs = [] kept_iterations = [] @@ -253,31 +258,34 @@ class Sky: return torch.cat(result, dim=0) - def frame2img(self, x, upscale=15): + def frame2img(self, x, scale=15): x = x.reshape(-1, self.height, self.width) m = torch.logical_and( x >= 0, x < self.first_bird_token + self.nb_bird_tokens ).long() x = self.colors[x * m].permute(0, 3, 1, 2) s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale) - x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale) + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0 - x[:, :, torch.arange(0, x.size(2), upscale), :] = 0 + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 x = x[:, :, 1:, 1:] for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): if m[n, i, j] == 0: - for k in range(2, upscale - 2): - x[n, :, i * upscale + k, j * upscale + k] = 0 - x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0 + for k in range(2, scale - 2): + for l in [0, 1]: + x[n, :, i * scale + k, j * scale + k - l] = 0 + x[ + n, :, i * scale + scale - 1 - k, j * scale + k - l + ] = 0 return x - def seq2img(self, seq, upscale=15): + def seq2img(self, seq, scale=15): f_first = seq[:, : self.height * self.width].reshape( -1, self.height, self.width ) @@ -287,47 +295,53 @@ class Sky: direction = seq[:, self.height * self.width] direction_symbol = torch.full( - (direction.size(0), self.height * upscale - 1, upscale), 0 + (direction.size(0), self.height * scale - 1, scale), 0 ) direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2) - separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0) + separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0) for n in range(direction_symbol.size(0)): if direction[n] == self.token_forward: - for k in range(upscale): - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - 3 + upscale // 2 - abs(k - upscale // 2), - ] = 0 + for k in range(scale): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + scale // 2 - abs(k - scale // 2), + ] = 0 elif direction[n] == self.token_backward: - for k in range(upscale): - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - 3 + abs(k - upscale // 2), - ] = 0 + for k in range(scale): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + abs(k - scale // 2), + ] = 0 else: - for k in range(2, upscale - 2): - direction_symbol[ - n, :, (self.height * upscale) // 2 - upscale // 2 + k, k - ] = 0 - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - upscale - 1 - k, - ] = 0 + for k in range(2, scale - 2): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + k, + ] = 0 + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + scale - 1 - k, + ] = 0 return torch.cat( [ - self.frame2img(f_first, upscale), + self.frame2img(f_first, scale), separator, direction_symbol, separator, - self.frame2img(f_second, upscale), + self.frame2img(f_second, scale), ], dim=3, ) @@ -338,6 +352,14 @@ class Sky: result.append("".join([self.token2char[v] for v in s])) return result + def save_image(self, input, result_dir, filename): + img = self.seq2img(input.to("cpu")) + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) + + def save_quizzes(self, input, result_dir, filename_prefix): + self.save_image(input, result_dir, filename_prefix + ".png") + ###################################################################### @@ -349,7 +371,7 @@ if __name__ == "__main__": start_time = time.perf_counter() seq, it = sky.generate_seq(nb=64, return_iterations=True) delay = time.perf_counter() - start_time - print(f"{seq.size(0)/delay:02f} samples/s") + print(f"{seq.size(0)/delay:02f} seq/s") print(sky.seq2str(seq[:4]))