X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=picoclvr.py;h=0cd306243fd79ee90728eb16340581aad077914f;hb=e39282eef52a7f5ab6654b999009127569b1b599;hp=cc937af2a10ca89cfb4328c604e790c6866252a6;hpb=bd9e5951a5741f7e3e44fc03379795eff83242d6;p=picoclvr.git diff --git a/picoclvr.py b/picoclvr.py index cc937af..0cd3062 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -5,6 +5,7 @@ # Written by Francois Fleuret +import math import torch, torchvision import torch.nn.functional as F @@ -196,14 +197,17 @@ def generate( nb_colors=5, pruner=None, ): - assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1 descr = [] for n in range(nb): - - nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] # color 0 is white and reserved for the background @@ -241,15 +245,8 @@ def generate( # Extracts the image after in descr as a 1x3xHxW tensor -def descr2img(descr, n, height, width): - - if type(descr) == list: - return torch.cat([descr2img(d, n, height, width) for d in descr], 0) - - if type(n) == list: - return torch.cat([descr2img(descr, k, height, width) for k in n], 0).unsqueeze( - 0 - ) +def descr2img(descr, height, width): + result = [] def token2color(t): try: @@ -257,15 +254,15 @@ def descr2img(descr, n, height, width): except KeyError: return [128, 128, 128] - d = descr.split("") - d = d[n + 1] if len(d) > n + 1 else "" - d = d.strip().split(" ")[: height * width] - d = d + [""] * (height * width - len(d)) - d = [token2color(t) for t in d] - img = torch.tensor(d).permute(1, 0) - img = img.reshape(1, 3, height, width) + for d in descr: + d = d.split("")[1] + d = d.strip().split(" ")[: height * width] + d = d + [""] * (height * width - len(d)) + d = [token2color(t) for t in d] + img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width) + result.append(img) - return img + return torch.cat(result, 0) ###################################################################### @@ -274,7 +271,6 @@ def descr2img(descr, n, height, width): def descr2properties(descr, height, width): - if type(descr) == list: return [descr2properties(d, height, width) for d in descr] @@ -319,7 +315,6 @@ def descr2properties(descr, height, width): def nb_properties(descr, height, width, pruner=None): - if type(descr) == list: return [nb_properties(d, height, width, pruner) for d in descr] @@ -353,7 +348,7 @@ if __name__ == "__main__": for d in descr: f.write(f"{d}\n\n") - img = descr2img(descr, n=0, height=12, width=16) + img = descr2img(descr, height=12, width=16) if img.size(0) == 1: img = F.pad(img, (1, 1, 1, 1), value=64)