X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=sky.py;h=6ef8a3af2184777c223dfd1803647a49bd3dd54d;hb=226589286bd8701002102062394909a82f5e807e;hp=e509fb775b8a2e01060de5907923d6cc6cfed523;hpb=d13eb42426139c4e59506db03f6b2bd68ade9b90;p=culture.git diff --git a/sky.py b/sky.py index e509fb7..6ef8a3a 100755 --- a/sky.py +++ b/sky.py @@ -229,6 +229,7 @@ class Sky(problem.Problem): margin = 4 img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) + h = img_prompts.size(2) img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) @@ -241,7 +242,7 @@ class Sky(problem.Problem): img_answers, c=predicted_answers, margin=margin, bottom=True ) - marker_size = 8 + marker_size = 16 separator = img_prompts.new_full( ( @@ -253,17 +254,20 @@ class Sky(problem.Problem): 255, ) - for k in range(2, 2 * marker_size - 3): - i = k + 1 - marker_size - j = marker_size - 2 - abs(k - marker_size + 1) - separator[:, :, separator.size(2) // 2 + i, j] = 0 - separator[:, :, separator.size(2) // 2 + i + 1, j] = 0 + separator[:, :, 0] = 0 + separator[:, :, h - 1] = 0 + + for k in range(1, 2 * marker_size - 8): + i = k - (marker_size - 4) + j = marker_size - 5 - abs(i) + separator[:, :, h // 2 - 1 + i, 2 + j] = 0 + separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 img = torch.cat([img_prompts, separator, img_answers], dim=3) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0 + img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0 ) ###################################################################### @@ -278,6 +282,7 @@ class Sky(problem.Problem): prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) + # warnings.warn("dirty test with longer answer", RuntimeWarning) # answers = torch.cat( # [