From: François Fleuret Date: Wed, 3 Jul 2024 14:40:18 +0000 (+0300) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ff436fbfeea7830a600c81175b5376d02925a1a0;p=culture.git Update. --- diff --git a/lang.py b/lang.py index d53386c..3d939bb 100755 --- a/lang.py +++ b/lang.py @@ -73,6 +73,7 @@ class Lang(problem.Problem): predicted_answers = 255 def add_frame(x, c, margin, bottom=False): + print(f"{type(x)=} {type(c)=}") if bottom: h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0 else: @@ -89,7 +90,7 @@ class Lang(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([0, 0, 0], device=c.device) + ( + c = c * torch.tensor([192, 192, 192], device=c.device) + ( 1 - c ) * torch.tensor([255, 255, 255], device=c.device) y[...] = c[:, :, None, None] @@ -98,44 +99,66 @@ class Lang(problem.Problem): return y - margin = 4 + margin = 8 - img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) - h = img_prompts.size(2) - img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) - - img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) - img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True) - - img_prompts = add_frame( - img_prompts, c=predicted_prompts, margin=margin, bottom=True + img_prompts = torch.cat( + [ + add_frame( + add_frame(self.frame2img(x), c=0, margin=1), + c=predicted_prompts, + margin=margin, + ) + for x in prompts.to("cpu").split(split_size=self.width, dim=2) + ], + dim=3, ) + + h = img_prompts.size(2) img_answers = add_frame( - img_answers, c=predicted_answers, margin=margin, bottom=True + add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1), + c=predicted_answers, + margin=margin, ) - marker_size = 16 + separator_size = 2 * margin separator = img_prompts.new_full( ( img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), - marker_size, + separator_size, ), 255, ) - separator[:, :, 0] = 0 - separator[:, :, h - 1] = 0 - - for k in range(1, 2 * marker_size - 8): - i = k - (marker_size - 4) - j = marker_size - 5 - abs(i) - separator[:, :, h // 2 - 1 + i, 2 + j] = 0 - separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + marker = img_prompts.new_full( + ( + img_prompts.size(0), + img_prompts.size(1), + img_prompts.size(2), + separator_size, + ), + 255, + ) - img = torch.cat([img_prompts, separator, img_answers], dim=3) + # marker[:, :, 0] = 0 + # marker[:, :, h - 1] = 0 + + for k in range(1, 2 * separator_size - 8): + i = k - (separator_size - 4) + j = separator_size - 5 - abs(i) + marker[:, :, h // 2 - 1 + i, 2 + j] = 0 + marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + + img = torch.cat( + [ + img_prompts, + marker, + img_answers, + ], + dim=3, + ) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( @@ -207,11 +230,11 @@ if __name__ == "__main__": prompts, answers = lang.generate_prompts_and_answers(24) - # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - # predicted_answers = torch.rand(answers.size(0)) < 0.5 + predicted_prompts = torch.rand(prompts.size(0)) < 0.5 + predicted_answers = torch.logical_not(predicted_prompts) lang.save_quizzes( - "/tmp", "test", prompts, answers # , predicted_prompts, predicted_answers + "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers ) # start_time = time.perf_counter()