From 74311726e42dccb8bc096e86a7e9000576099bab Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 21 Jun 2023 22:10:48 +0200 Subject: [PATCH] Update. --- main.py | 10 ++++----- snake.py | 66 ++++++++++++++++++++++++++++++++------------------------ 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/main.py b/main.py index 9679236..7cb8d4f 100755 --- a/main.py +++ b/main.py @@ -102,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8) parser.add_argument("--snake_nb_colors", type=int, default=5) -parser.add_argument("--snake_length", type=int, default=400) +parser.add_argument("--snake_length", type=int, default=200) ###################################################################### @@ -143,8 +143,8 @@ default_args = { "batch_size": 25, }, "snake": { - "nb_epochs": 25, - "batch_size": 20, + "nb_epochs": 5, + "batch_size": 25, }, } @@ -689,7 +689,7 @@ class TaskSnake(Task): self.device = device self.prompt_length = prompt_length - self.train_input, self.train_prior_visits = snake.generate_sequences( + self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences( nb_train_samples, height, width, @@ -698,7 +698,7 @@ class TaskSnake(Task): prompt_length, self.device, ) - self.test_input, self.test_prior_visits = snake.generate_sequences( + self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences( nb_test_samples, height, width, diff --git a/snake.py b/snake.py index eb46a07..7c34941 100755 --- a/snake.py +++ b/snake.py @@ -13,7 +13,7 @@ def generate_sequences( nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu") ): worlds = torch.randint(nb_colors, (nb, height, width), device=device) - nb_prior_visits = torch.zeros(nb, height, width, device=device) + world_prior_visits = torch.zeros(nb, height, width, device=device) # nb x 2 snake_position = torch.cat( @@ -70,17 +70,17 @@ def generate_sequences( snake_direction = snake_next_direction[i, j] sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 - sequences_prior_visits[:, 2 * l] = nb_prior_visits[ + sequences_prior_visits[:, 2 * l] = world_prior_visits[ i, snake_position[:, 0], snake_position[:, 1] ] if l < prompt_length: - nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 + world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 sequences[:, 2 * l + 1] = snake_direction # nb x 2 snake_position = snake_next_position[i, j] - return sequences, sequences_prior_visits + return sequences, sequences_prior_visits, worlds, world_prior_visits # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) @@ -114,32 +114,42 @@ def solver(input, ar_mask): ###################################################################### if __name__ == "__main__": - for n in range(16): - descr = generate(nb=1, height=12, width=16) - - print(nb_properties(descr, height=12, width=16)) - - with open(f"picoclvr_example_{n:02d}.txt", "w") as f: - for d in descr: - f.write(f"{d}\n\n") - - img = descr2img(descr, height=12, width=16) - if img.size(0) == 1: - img = F.pad(img, (1, 1, 1, 1), value=64) + import cairo, numpy, math + + color_name2rgb = { + "red": [255, 0, 0], + "green": [0, 128, 0], + "blue": [0, 0, 255], + "yellow": [255, 255, 0], + "orange": [255, 128, 0], + "maroon": [128, 0, 0], + "dark_red": [139, 0, 0], + "brown": [165, 42, 42], + "firebrick": [178, 34, 34], + "crimson": [220, 20, 60], + "tomato": [255, 99, 71], + "coral": [255, 127, 80], + "indian_red": [205, 92, 92], + "light_coral": [240, 128, 128], + "dark_salmon": [233, 150, 122], + "salmon": [250, 128, 114], + } + + sequences, sequences_prior_visits, worlds, world_prior_visits = generate_sequences( + 8, 6, 8, 5, 20, 10 + ) - torchvision.utils.save_image( - img / 255.0, - f"picoclvr_example_{n:02d}.png", - padding=1, - nrow=4, - pad_value=0.8, - ) + delta = 16 + height, width = sequences.size(0) * 16, sequences.size(1) * 16 + pixel_map = torch.ByteTensor(width, height, 4).fill_(0).numpy() + surface = cairo.ImageSurface.create_for_data( + pixel_map, cairo.FORMAT_ARGB32, width, height + ) + ctx = cairo.Context(surface) + ctx.set_line_width(1.0) - import time + ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD) - start_time = time.perf_counter() - descr = generate(nb=1000, height=12, width=16) - end_time = time.perf_counter() - print(f"{len(descr) / (end_time - start_time):.02f} samples per second") + ctx.fill() ###################################################################### -- 2.39.5