X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=ed440d37a37af02b6b2a19cb6ce20945d3d59afe;hb=30c76210e3ed2704b2a059208f385cb623c1486d;hp=d2a4568919c31f097966979988914a529652880a;hpb=3b41e2797fc340fd11cb35015b57c3cae1e8447b;p=culture.git

diff --git a/sky.py b/sky.py
index d2a4568..ed440d3 100755
--- a/sky.py
+++ b/sky.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret
 
-import math, sys, tqdm, os
+import math, sys, tqdm, os, warnings
 
 import torch, torchvision
 
@@ -42,9 +42,6 @@ class Sky(problem.Problem):
         "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
     )
 
-    def nb_token_values(self):
-        return len(self.colors)
-
     def __init__(
         self,
         height=6,
@@ -155,17 +152,8 @@ class Sky(problem.Problem):
 
     ######################################################################
 
-    def generate_prompts_and_answers(self, nb):
-        frame_sequences = self.generate_frame_sequences(nb)
-        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
-        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
-        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
-        return prompts, answers
-
-    ######################################################################
-
     def frame2img(self, x, scale=15):
-        x = x.reshape(-1, self.height, self.width)
+        x = x.reshape(x.size(0), self.height, -1)
         m = torch.logical_and(
             x >= 0, x < self.first_bird_token + self.nb_bird_tokens
         ).long()
@@ -212,44 +200,102 @@ class Sky(problem.Problem):
         if predicted_answers is None:
             predicted_answers = 255
 
-        def add_frame(x, c, margin):
-            y = x.new_full(
-                (x.size(0), x.size(1), x.size(2) + 2 * margin, x.size(3) + 2 * margin),
-                0,
-            )
+        def add_frame(x, c, margin, bottom=False):
+            if bottom:
+                h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0
+            else:
+                h, w, di, dj = (
+                    x.size(2) + 2 * margin,
+                    x.size(3) + 2 * margin,
+                    margin,
+                    margin,
+                )
+
+            y = x.new_full((x.size(0), x.size(1), h, w), 0)
+
             if type(c) is int:
                 y[...] = c
             else:
                 c = c.long()[:, None]
-                c = c * torch.tensor([192, 192, 192], device=c.device) + (
-                    1 - c
-                ) * torch.tensor([255, 255, 255], device=c.device)
+                c = (
+                    (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+                    + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+                    + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+                )
                 y[...] = c[:, :, None, None]
-            y[:, :, margin:-margin, margin:-margin] = x
+
+            y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x
+
             return y
 
         margin = 4
 
-        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), 0, 1)
-        img_answers = add_frame(self.frame2img(answers.to("cpu")), 0, 1)
+        img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1)
+        h = img_prompts.size(2)
+        img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1)
 
-        # img_prompts = add_frame(img_prompts, 255, margin)
-        # img_answers = add_frame(img_answers, 255, margin)
+        img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True)
+        img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True)
 
-        img_prompts = add_frame(img_prompts, predicted_prompts, margin)
-        img_answers = add_frame(img_answers, predicted_answers, margin)
+        img_prompts = add_frame(
+            img_prompts, c=predicted_prompts, margin=margin, bottom=True
+        )
+        img_answers = add_frame(
+            img_answers, c=predicted_answers, margin=margin, bottom=True
+        )
+
+        marker_size = 16
 
         separator = img_prompts.new_full(
-            (img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), margin), 255
+            (
+                img_prompts.size(0),
+                img_prompts.size(1),
+                img_prompts.size(2),
+                marker_size,
+            ),
+            255,
         )
 
-        img = torch.cat([img_prompts, img_answers], dim=3)
+        separator[:, :, 0] = 0
+        separator[:, :, h - 1] = 0
+
+        for k in range(1, 2 * marker_size - 8):
+            i = k - (marker_size - 4)
+            j = marker_size - 5 - abs(i)
+            separator[:, :, h // 2 - 1 + i, 2 + j] = 0
+            separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0
+
+        img = torch.cat([img_prompts, separator, img_answers], dim=3)
 
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(
-            img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0
+            img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0
         )
 
+    ######################################################################
+
+    def nb_token_values(self):
+        return len(self.colors)
+
+    def generate_prompts_and_answers(self, nb):
+        frame_sequences = self.generate_frame_sequences(nb)
+        frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0)
+
+        prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1)
+
+        answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1)
+
+        # warnings.warn("dirty test with longer answer", RuntimeWarning)
+        # answers = torch.cat(
+        #     [
+        #         frame_sequences[:, frame_sequences.size(1) // 2 :],
+        #         frame_sequences[:, frame_sequences.size(1) // 2 :],
+        #     ],
+        #     dim=3,
+        # ).flatten(1)
+
+        return prompts, answers
+
     def save_quizzes(
         self,
         result_dir,
@@ -274,12 +320,12 @@ class Sky(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
+    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)
 
     prompts, answers = sky.generate_prompts_and_answers(4)
 
-    predicted_prompts = torch.rand(prompts.size(0)) < 0.5
-    predicted_answers = torch.rand(answers.size(0)) < 0.5
+    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
+    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
 
     sky.save_quizzes(
         "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
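
A minimal usage sketch, assuming sky.py and its problem dependency are importable: with this change, the predicted_prompts / predicted_answers arguments of save_quizzes are tri-state markers in {-1, 0, 1} rather than booleans, and add_frame(..., bottom=True) draws a strip under each grid, green for 1, white for 0, red for -1. A driver mirroring the updated __main__ block would look like this:

    import torch

    from sky import Sky  # assumes sky.py (and its problem module) are on the path

    sky = Sky(height=6, width=8, speed=1, nb_iterations=4)

    prompts, answers = sky.generate_prompts_and_answers(4)

    # per-quiz markers: 1 -> green strip, 0 -> white, -1 -> red
    predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
    predicted_answers = torch.randint(3, (prompts.size(0),)) - 1

    sky.save_quizzes(
        "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
    )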