From 00d3d613cfd03df6bb56b034b55b9a98157d5e3e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 18:51:04 +0300 Subject: [PATCH] Update. --- lang.py | 78 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/lang.py b/lang.py index 5adf50f..4c3c64f 100755 --- a/lang.py +++ b/lang.py @@ -183,18 +183,20 @@ class Lang(problem.Problem): break return i1, j1, i2, j2 + ###################################################################### + def task_replace_color(self, A, f_A, B, f_B): c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 - for n, X, f_X in [(1, A, f_A), (3, B, f_B)]: + for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: for _ in range(torch.randint(n, (1,)) + 1): i1, j1, i2, j2 = self.rec_coo(X) X[i1:i2, j1:j2] = c1 f_X[i1:i2, j1:j2] = c2 def task_move(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:1] + 1 di, dj = torch.randint(2, (2,)) * 2 - 1 - for n, X, f_X in [(1, A, f_A), (3, B, f_B)]: + for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: + c = torch.randperm(len(self.colors) - 1)[:1] + 1 for _ in range(torch.randint(n, (1,)) + 1): while True: i1, j1, i2, j2 = self.rec_coo(X) @@ -210,20 +212,44 @@ class Lang(problem.Problem): f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c def task_grow(self, A, f_A, B, f_B): - c = torch.randperm(len(self.colors) - 1)[:1] + 1 - - for n, X, f_X in [(1, A, f_A), (3, B, f_B)]: + direction = torch.randint(2, (1,)) + for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: + c = torch.randperm(len(self.colors) - 1)[:1] + 1 for _ in range(torch.randint(n, (1,)) + 1): while True: i1, j1, i2, j2 = self.rec_coo(X) if i1 + 3 < i2 and j1 + 3 < j2: break - X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c + if direction == 0: + X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c + f_X[i1:i2, j1:j2] = c + else: + X[i1:i2, j1:j2] = c + f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c + + def task_frame(self, A, f_A, B, f_B): + direction = torch.randint(2, (1,)) + for n, X, f_X in [(1, A, f_A), (1, B, f_B)]: + c = torch.randperm(len(self.colors) - 1)[:1] + 1 + for _ in range(torch.randint(n, (1,)) + 1): + while True: + i1, j1, i2, j2 = self.rec_coo(X) + if i1 + 3 < i2 and j1 + 3 < j2: + break + X[i1:i2, j1:j2] = c f_X[i1:i2, j1:j2] = c + f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 + + ###################################################################### def generate_prompts_and_answers(self, nb): - tasks = [self.task_replace_color, self.task_move, self.task_grow] + tasks = [ + self.task_replace_color, + self.task_move, + self.task_grow, + self.task_frame, + ] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width @@ -263,40 +289,14 @@ if __name__ == "__main__": prompts, answers = lang.generate_prompts_and_answers(36) - predicted_prompts = torch.rand(prompts.size(0)) < 0.5 - predicted_answers = torch.logical_not(predicted_prompts) + # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 + # predicted_answers = torch.logical_not(predicted_prompts) lang.save_quizzes( "/tmp", "test", prompts, - answers, # predicted_prompts, predicted_answers + answers, + # You can add a bool to put a frame around the predicted parts + # predicted_prompts, predicted_answers ) - - # start_time = time.perf_counter() - # token_sequences = lang.generate_token_sequences(nb=64) - # delay = time.perf_counter() - start_time - # print(f"{token_sequences.size(0)/delay:02f} seq/s") - - # print(lang.seq2str(seq[:4])) - - # for t in range(len(it[0])): - # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0) - # torchvision.utils.save_image( - # img.float() / 255.0, - # f"/tmp/frame_{t:03d}.png", - # nrow=8, - # padding=6, - # pad_value=0, - # ) - - # m = (torch.rand(seq.size()) < 0.05).long() - # seq = (1 - m) * seq + m * 23 - - # print(seq.size()) - # img = lang.seq2img(token_sequences) - # print(img.size()) - - # torchvision.utils.save_image( - # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0 - # ) -- 2.39.5