nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
):
worlds = torch.randint(nb_colors, (nb, height, width), device=device)
- nb_prior_visits = torch.zeros(nb, height, width, device=device)
+ world_prior_visits = torch.zeros(nb, height, width, device=device)
# nb x 2
snake_position = torch.cat(
snake_direction = snake_next_direction[i, j]
sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
- sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+ sequences_prior_visits[:, 2 * l] = world_prior_visits[
i, snake_position[:, 0], snake_position[:, 1]
]
if l < prompt_length:
- nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
+ world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
sequences[:, 2 * l + 1] = snake_direction
# nb x 2
snake_position = snake_next_position[i, j]
- return sequences, sequences_prior_visits
+ return sequences, sequences_prior_visits, worlds, world_prior_visits
# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
######################################################################
if __name__ == "__main__":
- for n in range(16):
- descr = generate(nb=1, height=12, width=16)
-
- print(nb_properties(descr, height=12, width=16))
-
- with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
- for d in descr:
- f.write(f"{d}\n\n")
-
- img = descr2img(descr, height=12, width=16)
- if img.size(0) == 1:
- img = F.pad(img, (1, 1, 1, 1), value=64)
+ import cairo, numpy, math
+
+ color_name2rgb = {
+ "red": [255, 0, 0],
+ "green": [0, 128, 0],
+ "blue": [0, 0, 255],
+ "yellow": [255, 255, 0],
+ "orange": [255, 128, 0],
+ "maroon": [128, 0, 0],
+ "dark_red": [139, 0, 0],
+ "brown": [165, 42, 42],
+ "firebrick": [178, 34, 34],
+ "crimson": [220, 20, 60],
+ "tomato": [255, 99, 71],
+ "coral": [255, 127, 80],
+ "indian_red": [205, 92, 92],
+ "light_coral": [240, 128, 128],
+ "dark_salmon": [233, 150, 122],
+ "salmon": [250, 128, 114],
+ }
+
+ sequences, sequences_prior_visits, worlds, world_prior_visits = generate_sequences(
+ 8, 6, 8, 5, 20, 10
+ )
- torchvision.utils.save_image(
- img / 255.0,
- f"picoclvr_example_{n:02d}.png",
- padding=1,
- nrow=4,
- pad_value=0.8,
- )
+ delta = 16
+ height, width = sequences.size(0) * 16, sequences.size(1) * 16
+ pixel_map = torch.ByteTensor(width, height, 4).fill_(0).numpy()
+ surface = cairo.ImageSurface.create_for_data(
+ pixel_map, cairo.FORMAT_ARGB32, width, height
+ )
+ ctx = cairo.Context(surface)
+ ctx.set_line_width(1.0)
- import time
+ ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
- start_time = time.perf_counter()
- descr = generate(nb=1000, height=12, width=16)
- end_time = time.perf_counter()
- print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
+ ctx.fill()
######################################################################