X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=snake.py;h=7c3494167f9ff04450e13c178f6fec2dc01206a0;hb=3f09462033feac19ad72ac1a4b8690e6330df22d;hp=eb46a076e5e548111c9ff100e73c1f904cc4d9d9;hpb=cf7fcbb7a946c4d1f4d29a28e0eb04940d3b0f76;p=picoclvr.git diff --git a/snake.py b/snake.py index eb46a07..7c34941 100755 --- a/snake.py +++ b/snake.py @@ -13,7 +13,7 @@ def generate_sequences( nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu") ): worlds = torch.randint(nb_colors, (nb, height, width), device=device) - nb_prior_visits = torch.zeros(nb, height, width, device=device) + world_prior_visits = torch.zeros(nb, height, width, device=device) # nb x 2 snake_position = torch.cat( @@ -70,17 +70,17 @@ def generate_sequences( snake_direction = snake_next_direction[i, j] sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4 - sequences_prior_visits[:, 2 * l] = nb_prior_visits[ + sequences_prior_visits[:, 2 * l] = world_prior_visits[ i, snake_position[:, 0], snake_position[:, 1] ] if l < prompt_length: - nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 + world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1 sequences[:, 2 * l + 1] = snake_direction # nb x 2 snake_position = snake_next_position[i, j] - return sequences, sequences_prior_visits + return sequences, sequences_prior_visits, worlds, world_prior_visits # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20) @@ -114,32 +114,42 @@ def solver(input, ar_mask): ###################################################################### if __name__ == "__main__": - for n in range(16): - descr = generate(nb=1, height=12, width=16) - - print(nb_properties(descr, height=12, width=16)) - - with open(f"picoclvr_example_{n:02d}.txt", "w") as f: - for d in descr: - f.write(f"{d}\n\n") - - img = descr2img(descr, height=12, width=16) - if img.size(0) == 1: - img = F.pad(img, (1, 1, 1, 1), value=64) + import cairo, numpy, math + + color_name2rgb = { + "red": [255, 0, 0], + "green": [0, 128, 0], + "blue": [0, 0, 255], + "yellow": [255, 255, 0], + "orange": [255, 128, 0], + "maroon": [128, 0, 0], + "dark_red": [139, 0, 0], + "brown": [165, 42, 42], + "firebrick": [178, 34, 34], + "crimson": [220, 20, 60], + "tomato": [255, 99, 71], + "coral": [255, 127, 80], + "indian_red": [205, 92, 92], + "light_coral": [240, 128, 128], + "dark_salmon": [233, 150, 122], + "salmon": [250, 128, 114], + } + + sequences, sequences_prior_visits, worlds, world_prior_visits = generate_sequences( + 8, 6, 8, 5, 20, 10 + ) - torchvision.utils.save_image( - img / 255.0, - f"picoclvr_example_{n:02d}.png", - padding=1, - nrow=4, - pad_value=0.8, - ) + delta = 16 + height, width = sequences.size(0) * 16, sequences.size(1) * 16 + pixel_map = torch.ByteTensor(width, height, 4).fill_(0).numpy() + surface = cairo.ImageSurface.create_for_data( + pixel_map, cairo.FORMAT_ARGB32, width, height + ) + ctx = cairo.Context(surface) + ctx.set_line_width(1.0) - import time + ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD) - start_time = time.perf_counter() - descr = generate(nb=1000, height=12, width=16) - end_time = time.perf_counter() - print(f"{len(descr) / (end_time - start_time):.02f} samples per second") + ctx.fill() ######################################################################