3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import torch, torchvision
9 import torch.nn.functional as F
191 "medium_spring_green",
194 "medium_aqua_marine",
251 "light_golden_rod_yellow",
# Maps each color name to its position in color_names.
color_id = {name: index for index, name in enumerate(color_names)}
# Maps each color name to the entry of colors at the same position.
color_tokens = dict(zip(color_names, colors))
290 ######################################################################
def all_properties(height, width, nb_squares, square_i, square_j, square_c):
    """Return every statement that holds for the given squares.

    Args:
        height: number of rows of the grid.
        width: number of columns of the grid.
        nb_squares: number of squares to describe; entries 0..nb_squares-1
            of the three sequences below are read.
        square_i: row index of each square (a larger row is lower: rows in
            the last third are reported as "bottom").
        square_j: column index of each square (columns in the last third
            are reported as "right").
        square_c: for each square, an index into the module-level
            ``color_names`` list.

    Returns:
        A list of strings. For each square of color C, in order:
        ``"there is C"``; the location statements ``"C bottom"``,
        ``"C top"``, ``"C right"``, ``"C left"`` that apply; then, for every
        square of color D, the relations ``"C below D"``, ``"C above D"``,
        ``"C right of D"``, ``"C left of D"`` that apply. A square produces
        no relation with itself since all four comparisons are strict.
    """
    # Resolve each square's color name once instead of once per pair.
    names = [color_names[square_c[k]] for k in range(nb_squares)]

    # The outer thirds of the grid define top/bottom and left/right.
    top_end = height // 3
    bottom_start = height - height // 3
    left_end = width // 3
    right_start = width - width // 3

    s = []

    for r, c_r in enumerate(names):
        s.append(f"there is {c_r}")

        if square_i[r] >= bottom_start:
            s.append(f"{c_r} bottom")
        if square_i[r] < top_end:
            s.append(f"{c_r} top")
        if square_j[r] >= right_start:
            s.append(f"{c_r} right")
        if square_j[r] < left_end:
            s.append(f"{c_r} left")

        for t, c_t in enumerate(names):
            if square_i[r] > square_i[t]:
                s.append(f"{c_r} below {c_t}")
            if square_i[r] < square_i[t]:
                s.append(f"{c_r} above {c_t}")
            if square_j[r] > square_j[t]:
                s.append(f"{c_r} right of {c_t}")
            if square_j[r] < square_j[t]:
                s.append(f"{c_r} left of {c_t}")

    return s
321 ######################################################################
323 # Generates sequences
# NOTE(review): the `def` line and the rest of the parameter list are not
# part of this excerpt. The visible lines below also read nb_colors,
# max_nb_squares, height, width and pruner.
331 max_nb_properties=10,
# nb_colors must be at least max_nb_squares (each square gets a distinct
# color below) and at most len(color_tokens) - 1 (index 0 is kept for the
# background, see the comment further down).
336 assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
# torch.randint(max_nb_squares, (1,)) draws from [0, max_nb_squares), so
# nb_squares is a 1-element tensor holding a value in [1, max_nb_squares].
342 nb_squares = torch.randint(max_nb_squares, (1,)) + 1
# Distinct flat cell indices: the first nb_squares entries of a random
# permutation of all height * width cells.
343 square_position = torch.randperm(height * width)[:nb_squares]
345 # color 0 is white and reserved for the background
# Distinct color indices in [1, nb_colors].
346 square_c = torch.randperm(nb_colors)[:nb_squares] + 1
# Convert the flat positions to (row, column), row-major.
347 square_i = square_position.div(width, rounding_mode="floor")
348 square_j = square_position % width
# Flat row-major image of color indices, all background (0) at first.
350 img = [0] * height * width
351 for k in range(nb_squares):
352 img[square_position[k]] = square_c[k]
354 # generates all the true properties
356 s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
# pruner, when given, is used as a filter predicate: only the statements
# for which it returns a truthy value are kept.
358 if pruner is not None:
359 s = list(filter(pruner, s))
361 # picks at most max_nb_properties at random
# NOTE: this name is also the name of a module-level function defined later
# in the file.
363 nb_properties = torch.randint(max_nb_properties, (1,)) + 1
# Random subset of the statements, joined with " <sep> ".
# NOTE(review): the lines that open and close this expression, and whatever
# is concatenated between the two joins, are not visible in this excerpt.
365 " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
# One color name per cell, space separated, in flat row-major order.
367 + " ".join([f"{color_names[n]}" for n in img])
375 ######################################################################
377 # Extracts the image after <img> in descr as a 1x3xHxW tensor
380 def descr2img(descr, n, height, width):
# descr is one description string or a list of them; n selects which
# "<img>" section of the string is decoded (0 = text after the first marker).
# A list of descriptions is decoded one by one and concatenated along dim 0.
382 if type(descr) == list:
383 return torch.cat([descr2img(d, n, height, width) for d in descr], 0)
# Here n is iterated, giving one decoded image per index.
# NOTE(review): the condition guarding this branch and the argument of
# unsqueeze( are not visible in this excerpt.
386 return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze(
# NOTE(review): the next two lines are fragments of a token -> color lookup
# (token2color is called below, but its def line is not visible here).
# Known tokens map through color_tokens; [128, 128, 128] is presumably the
# fallback for unknown tokens — confirm.
392 return color_tokens[t]
394 return [128, 128, 128]
# Text following the (n + 1)-th "<img>" marker, or "" if there is none.
396 d = descr.split("<img>")
397 d = d[n + 1] if len(d) > n + 1 else ""
# Keep at most height * width tokens and pad a short image with "<unk>".
398 d = d.strip().split(" ")[: height * width]
399 d = d + ["<unk>"] * (height * width - len(d))
# One color per cell; assumes each color is a 3-component list, consistent
# with the fallback above and the reshape below.
400 d = [token2color(t) for t in d]
# (cells, 3) -> (3, cells) -> 1 x 3 x height x width.
401 img = torch.tensor(d).permute(1, 0)
402 img = img.reshape(1, 3, height, width)
407 ######################################################################
409 # Returns all the properties of the image after <img> in descr
412 def descr2properties(descr, height, width):
# A list input is processed element-wise and yields a list of results.
414 if type(descr) == list:
415 return [descr2properties(d, height, width) for d in descr]
# Use the text after the last "<img>" marker ("" if the marker is absent).
417 d = descr.split("<img>")
418 d = d[-1] if len(d) > 1 else ""
419 d = d.strip().split(" ")[: height * width]
# An image that does not have exactly height * width tokens is treated
# specially.
# NOTE(review): the body of this branch and the initialisation of `seen`
# are not visible in this excerpt.
420 if len(d) != height * width:
# Scan the cells in flat row-major order; color_names[0] is the background
# token and is skipped.
424 for k, x in enumerate(d):
425 if x != color_names[0]:
# NOTE(review): the handling of tokens that are not in color_tokens, and of
# a color met more than once, is on lines not visible in this excerpt.
426 if x in color_tokens:
# Record (color index, row, column) for this cell.
431 seen[x] = (color_id[x], k // width, k % width)
# Transpose the per-square triplets into three parallel tuples (colors,
# rows, columns); the result is an empty tuple when no square was recorded.
433 square_infos = tuple(zip(*seen.values()))
# NOTE(review): the condition selecting between the two groups of
# assignments below is not visible; the second group presumably covers the
# case where no square was found — confirm.
436 square_c = torch.tensor(square_infos[0])
437 square_i = torch.tensor(square_infos[1])
438 square_j = torch.tensor(square_infos[2])
440 square_c = torch.tensor([])
441 square_i = torch.tensor([])
442 square_j = torch.tensor([])
# Derive every statement that holds for the recovered squares.
444 s = all_properties(height, width, len(seen), square_i, square_j, square_c)
449 ######################################################################
451 # Returns a triplet composed of (1) the total number of properties
452 # before <img> in descr, (2) the total number of properties the image
453 # after <img> verifies, and (3) the number of properties in (1) not in (2)
457 def nb_properties(descr, height, width, pruner=None):
# A list input is processed element-wise and yields a list of results.
459 if type(descr) == list:
460 return [nb_properties(d, height, width, pruner) for d in descr]
# Split at the first "<img>" only: d[0] is the text part listing the
# requested properties.
# NOTE(review): two lines between this split and the next statement are
# not visible in this excerpt.
462 d = descr.split("<img>", 1)
# Requested properties are separated by "<sep>"; strip surrounding blanks.
465 d = d[0].strip().split("<sep>")
466 d = [x.strip() for x in d]
# Properties actually verified by the image part of descr.
# NOTE: this local has the same name as the module-level function
# all_properties and hides it inside this function.
468 all_properties = set(descr2properties(descr, height, width))
# NOTE(review): the branch headers choosing between the next two
# assignments are not visible; presumably the first is used when pruner is
# None and the second otherwise — confirm.
471 requested_properties = set(d)
473 requested_properties = set(filter(pruner, d))
# Requested properties that the image does not verify.
475 missing_properties = requested_properties - all_properties
477 return (len(requested_properties), len(all_properties), len(missing_properties))
480 ######################################################################
# Demo / smoke test, run only when the file is executed as a script.
482 if __name__ == "__main__":
# Generate a single 12 x 16 example and report its property counts.
# NOTE(review): `n`, used in the file names below, is not bound by any line
# visible in this excerpt (presumably an enclosing loop — confirm).
484 descr = generate(nb=1, height=12, width=16)
486 print(nb_properties(descr, height=12, width=16))
# NOTE(review): the body of this `with` block is not visible here.
488 with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
# Decode the first "<img>" section of the description into an image tensor.
492 img = descr2img(descr, n=0, height=12, width=16)
# Add a 1-pixel border of value 64 on all four sides of the last two dims.
# NOTE(review): a line between the previous statement and this one is not
# visible; it may make the padding conditional.
494 img = F.pad(img, (1, 1, 1, 1), value=64)
# NOTE(review): only the file-name argument of this call is visible here.
496 torchvision.utils.save_image(
498 f"picoclvr_example_{n:02d}.png",
# Throughput measurement: time the generation of 1000 examples.
# NOTE(review): the import of `time` is not visible in this excerpt.
506 start_time = time.perf_counter()
507 descr = generate(nb=1000, height=12, width=16)
508 end_time = time.perf_counter()
509 print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
511 ######################################################################