From: François Fleuret Date: Wed, 26 Jul 2023 21:32:24 +0000 (-1000) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=6045e9a7dd61f0dab60bd1c6ff71f6bd5c32778b;p=picoclvr.git Update. --- diff --git a/do_all.sh b/do_all.sh new file mode 100755 index 0000000..76f1982 --- /dev/null +++ b/do_all.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +################################################################## +# START_IP_HEADER # +# # +# Written by Francois Fleuret # +# Contact for comments & bug reports # +# # +# END_IP_HEADER # +################################################################## + +# set -e +# set -o pipefail + +#prefix="--nb_train_samples=1000 --nb_test_samples=100 --batch_size=25 --nb_epochs=2 --max_percents_of_test_in_train=-1 --model=17K" +prefix="--nb_epochs=2" + +for task in byheart learnop guessop twotargets addition picoclvr maze snake stack expr rpl +do + [[ ! -d results_${task} ]] && ./main.py ${prefix} --task=${task} +done + diff --git a/picoclvr.py b/picoclvr.py index 5da3943..0cd3062 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -5,6 +5,7 @@ # Written by Francois Fleuret +import math import torch, torchvision import torch.nn.functional as F @@ -201,7 +202,12 @@ def generate( descr = [] for n in range(nb): - nb_squares = torch.randint(max_nb_squares, (1,)) + 1 + # we want uniform over the combinations of 1 to max_nb_squares + # pixels of nb_colors + logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float() + dist = torch.distributions.categorical.Categorical(logits=logits) + nb_squares = dist.sample((1,)) + 1 + # nb_squares = torch.randint(max_nb_squares, (1,)) + 1 square_position = torch.randperm(height * width)[:nb_squares] # color 0 is white and reserved for the background diff --git a/problems.py b/problems.py index 5686404..2c8602c 100755 --- a/problems.py +++ b/problems.py @@ -87,7 +87,7 @@ class ProblemByHeart(Problem): class ProblemLearnOperator(Problem): - def __init__(self, nb_operators=100, len_source=5, len_result=8): + def __init__(self, nb_operators=100, len_source=6, len_result=9): self.len_source = len_source self.len_result = len_result self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1