X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=world.py;fp=world.py;h=da7de75bd143e95244812b6666179ff915bd5d1e;hb=a92a5ca00f4277f7a133fa6cfaada2bc1981f524;hp=fa305cf1d4d03408924928e6b029f8f8d6a305cb;hpb=2ac9d1299a84f96228f49fbdac02d5a7017445e5;p=picoclvr.git diff --git a/world.py b/world.py index fa305cf..da7de75 100755 --- a/world.py +++ b/world.py @@ -62,12 +62,20 @@ class SignSTE(nn.Module): return s +def loss_H(binary_logits, h_threshold=1): + p = binary_logits.sigmoid().mean(0) + h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2) + h.clamp_(max=h_threshold) + return h_threshold - h.mean() + + def train_encoder( train_input, test_input, depth=2, dim_hidden=48, nb_bits_per_token=8, + lambda_entropy=0.0, lr_start=1e-3, lr_end=1e-4, nb_epochs=10, @@ -160,6 +168,9 @@ def train_encoder( train_loss = F.cross_entropy(output, input) + if lambda_entropy > 0: + loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5) + acc_train_loss += train_loss.item() * input.size(0) optimizer.zero_grad() @@ -238,7 +249,7 @@ def scene2tensor(xh, yh, scene, size): ) -def random_scene(): +def random_scene(nb_insert_attempts=3): scene = [] colors = [ ((Box.nb_rgb_levels - 1), 0, 0), @@ -252,7 +263,7 @@ def random_scene(): ), ] - for k in range(10): + for k in range(nb_insert_attempts): wh = torch.rand(2) * 0.2 + 0.2 xy = torch.rand(2) * (1 - wh) c = colors[torch.randint(len(colors), (1,))] @@ -286,14 +297,15 @@ def generate_episode(steps, size=64): xh, yh = tuple(x.item() for x in torch.rand(2)) actions = torch.randint(len(effects), (len(steps),)) - change = False + nb_changes = 0 for s, a in zip(steps, actions): if s: frames.append(scene2tensor(xh, yh, scene, size=size)) - g, dx, dy = effects[a] - if g: + grasp, dx, dy = effects[a] + + if grasp: for b in scene: if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh: x, y = b.x, b.y @@ -310,7 +322,7 @@ def generate_episode(steps, size=64): else: xh += dx yh += dy - change = True + nb_changes += 1 else: x, y = xh, yh xh += dx @@ -318,7 +330,7 @@ def generate_episode(steps, size=64): if xh < 0 or xh > 1 or yh < 0 or yh > 1: xh, yh = x, y - if change: + if nb_changes > len(steps) // 3: break return frames, actions @@ -352,12 +364,21 @@ def create_data_and_processors( steps = [True] + [False] * (nb_steps + 1) + [True] train_input, train_actions = generate_episodes(nb_train_samples, steps) - train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage) + train_input, train_actions = train_input.to(device_storage), train_actions.to( + device_storage + ) test_input, test_actions = generate_episodes(nb_test_samples, steps) - test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage) + test_input, test_actions = test_input.to(device_storage), test_actions.to( + device_storage + ) encoder, quantizer, decoder = train_encoder( - train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device + train_input, + test_input, + lambda_entropy=1.0, + nb_epochs=nb_epochs, + logger=logger, + device=device, ) encoder.train(False) quantizer.train(False) @@ -371,7 +392,7 @@ def create_data_and_processors( seq = [] p = pow2.to(device) for x in input.split(batch_size): - x=x.to(device) + x = x.to(device) z = encoder(x) ze_bool = (quantizer(z) >= 0).long() output = (