X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=1768a81cd981323dd85d6200c34c704654e59fd2;hb=050976a525fee2d3b824350a3058ab7299a2bd3d;hp=e509fb775b8a2e01060de5907923d6cc6cfed523;hpb=d13eb42426139c4e59506db03f6b2bd68ade9b90;p=culture.git diff --git a/sky.py b/sky.py index e509fb7..1768a81 100755 --- a/sky.py +++ b/sky.py @@ -50,7 +50,11 @@ class Sky(problem.Problem): speed=2, nb_iterations=2, avoid_collision=True, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, ): + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) self.height = height self.width = width self.nb_birds = nb_birds @@ -217,9 +221,11 @@ class Sky(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([0, 0, 0], device=c.device) + ( - 1 - c - ) * torch.tensor([255, 255, 255], device=c.device) + c = ( + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + ) y[...] = c[:, :, None, None] y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x @@ -229,6 +235,7 @@ class Sky(problem.Problem): margin = 4 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) + h = img_prompts.size(2) img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) @@ -241,7 +248,7 @@ class Sky(problem.Problem): img_answers, c=predicted_answers, margin=margin, bottom=True ) - marker_size = 8 + marker_size = 16 separator = img_prompts.new_full( ( @@ -253,17 +260,20 @@ class Sky(problem.Problem): 255, ) - for k in range(2, 2 * marker_size - 3): - i = k + 1 - marker_size - j = marker_size - 2 - abs(k - marker_size + 1) - separator[:, :, separator.size(2) // 2 + i, j] = 0 - separator[:, :, separator.size(2) // 2 + i + 1, j] = 0 + separator[:, :, 0] = 0 + separator[:, :, h - 1] = 0 + + for k in range(1, 2 * marker_size - 8): + i = k - (marker_size - 4) + j = marker_size - 5 - abs(i) + separator[:, :, h // 2 - 1 + i, 2 + j] = 0 + separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 img = torch.cat([img_prompts, separator, img_answers], dim=3) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0 + img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0 ) ###################################################################### @@ -278,6 +288,7 @@ class Sky(problem.Problem): prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) + # warnings.warn("dirty test with longer answer", RuntimeWarning) # answers = torch.cat( # [ @@ -317,8 +328,8 @@ if __name__ == "__main__": prompts, answers = sky.generate_prompts_and_answers(4) - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.rand(answers.size(0)) < 0.5 + predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 + predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 sky.save_quizzes( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers