X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=lang.py;h=ce159e2d23ffca49085002a5d0268e09655e4b49;hb=5ecadfde470059278aec2b8ded217219e6773c04;hp=d53386c5761e917ebd9592e79875c50bc698c811;hpb=226589286bd8701002102062394909a82f5e807e;p=culture.git diff --git a/lang.py b/lang.py index d53386c..ce159e2 100755 --- a/lang.py +++ b/lang.py @@ -66,6 +66,9 @@ class Lang(problem.Problem): predicted_prompts=None, predicted_answers=None, ): + prompts = prompts.reshape(prompts.size(0), self.height, -1) + answers = answers.reshape(answers.size(0), self.height, -1) + if predicted_prompts is None: predicted_prompts = 255 @@ -89,7 +92,7 @@ class Lang(problem.Problem): y[...] = c else: c = c.long()[:, None] - c = c * torch.tensor([0, 0, 0], device=c.device) + ( + c = c * torch.tensor([192, 192, 192], device=c.device) + ( 1 - c ) * torch.tensor([255, 255, 255], device=c.device) y[...] = c[:, :, None, None] @@ -98,44 +101,66 @@ class Lang(problem.Problem): return y - margin = 4 - - img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) - h = img_prompts.size(2) - img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) - - img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) - img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True) + margin = 8 - img_prompts = add_frame( - img_prompts, c=predicted_prompts, margin=margin, bottom=True + img_prompts = torch.cat( + [ + add_frame( + add_frame(self.frame2img(x), c=0, margin=1), + c=predicted_prompts, + margin=margin, + ) + for x in prompts.to("cpu").split(split_size=self.width, dim=2) + ], + dim=3, ) + + h = img_prompts.size(2) img_answers = add_frame( - img_answers, c=predicted_answers, margin=margin, bottom=True + add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1), + c=predicted_answers, + margin=margin, ) - marker_size = 16 + separator_size = 2 * margin separator = img_prompts.new_full( ( img_prompts.size(0), img_prompts.size(1), img_prompts.size(2), - marker_size, + separator_size, ), 255, ) - separator[:, :, 0] = 0 - separator[:, :, h - 1] = 0 - - for k in range(1, 2 * marker_size - 8): - i = k - (marker_size - 4) - j = marker_size - 5 - abs(i) - separator[:, :, h // 2 - 1 + i, 2 + j] = 0 - separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + marker = img_prompts.new_full( + ( + img_prompts.size(0), + img_prompts.size(1), + img_prompts.size(2), + separator_size, + ), + 255, + ) - img = torch.cat([img_prompts, separator, img_answers], dim=3) + # marker[:, :, 0] = 0 + # marker[:, :, h - 1] = 0 + + for k in range(1, 2 * separator_size - 8): + i = k - (separator_size - 4) + j = separator_size - 5 - abs(i) + marker[:, :, h // 2 - 1 + i, 2 + j] = 0 + marker[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + + img = torch.cat( + [ + img_prompts, + marker, + img_answers, + ], + dim=3, + ) image_name = os.path.join(result_dir, filename) torchvision.utils.save_image( @@ -158,26 +183,42 @@ class Lang(problem.Problem): break return i1, j1, i2, j2 - def task_red_to_green(self, A, f_A, B, f_B): + def task_replace_color(self, A, f_A, B, f_B): + c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 i1, j1, i2, j2 = self.rec_coo(A) - A[i1:i2, j1:j2] = self.name2color["red"] - f_A[i1:i2, j1:j2] = self.name2color["green"] - i1, j1, i2, j2 = self.rec_coo(B) - B[i1:i2, j1:j2] = self.name2color["red"] - f_B[i1:i2, j1:j2] = self.name2color["green"] + A[i1:i2, j1:j2] = c1 + f_A[i1:i2, j1:j2] = c2 + for _ in range(3): + i1, j1, i2, j2 = self.rec_coo(B) + B[i1:i2, j1:j2] = c1 + f_B[i1:i2, j1:j2] = c2 + + def move_color(self, A, f_A, B, f_B): + c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 + + i1, j1, i2, j2 = self.rec_coo(A) + A[i1:i2, j1:j2] = c1 + f_A[i1:i2, j1:j2] = c1 + + while True: + i1, j1, i2, j2 = self.rec_coo(A) + if i2 < self.height - 1: + break + A[i1:i2, j1:j2] = c2 + f_A[i1:i2, j1:j2] = c2 def generate_prompts_and_answers(self, nb): prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width for prompt, answer in zip(prompts, answers): - self.task_red_to_green( - prompt[:, 0 * w : 1 * w], - prompt[:, 1 * w : 2 * w], - prompt[:, 2 * w : 3 * w], - answer, - ) - return prompts, answers + A = prompt[:, 0 * w : 1 * w] + f_A = prompt[:, 1 * w : 2 * w] + B = prompt[:, 2 * w : 3 * w] + f_B = answer + # self.task_replace_color(A, f_A, B, f_B) + self.move_color(A, f_A, B, f_B) + return prompts.flatten(1), answers.flatten(1) def save_quizzes( self, @@ -207,11 +248,11 @@ if __name__ == "__main__": prompts, answers = lang.generate_prompts_and_answers(24) - # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - # predicted_answers = torch.rand(answers.size(0)) < 0.5 + predicted_prompts = torch.rand(prompts.size(0)) < 0.5 + predicted_answers = torch.logical_not(predicted_prompts) lang.save_quizzes( - "/tmp", "test", prompts, answers # , predicted_prompts, predicted_answers + "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers ) # start_time = time.perf_counter()