Update.
[picoclvr.git] / world.py
index fa305cf..da7de75 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -62,12 +62,20 @@ class SignSTE(nn.Module):
             return s
 
 
+def loss_H(binary_logits, h_threshold=1):
+    p = binary_logits.sigmoid().mean(0)
+    h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+    h.clamp_(max=h_threshold)
+    return h_threshold - h.mean()
+
+
 def train_encoder(
     train_input,
     test_input,
     depth=2,
     dim_hidden=48,
     nb_bits_per_token=8,
+    lambda_entropy=0.0,
     lr_start=1e-3,
     lr_end=1e-4,
     nb_epochs=10,
@@ -160,6 +168,9 @@ def train_encoder(
 
             train_loss = F.cross_entropy(output, input)
 
+            if lambda_entropy > 0:
+                loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
             acc_train_loss += train_loss.item() * input.size(0)
 
             optimizer.zero_grad()
@@ -238,7 +249,7 @@ def scene2tensor(xh, yh, scene, size):
     )
 
 
-def random_scene():
+def random_scene(nb_insert_attempts=3):
     scene = []
     colors = [
         ((Box.nb_rgb_levels - 1), 0, 0),
@@ -252,7 +263,7 @@ def random_scene():
         ),
     ]
 
-    for k in range(10):
+    for k in range(nb_insert_attempts):
         wh = torch.rand(2) * 0.2 + 0.2
         xy = torch.rand(2) * (1 - wh)
         c = colors[torch.randint(len(colors), (1,))]
@@ -286,14 +297,15 @@ def generate_episode(steps, size=64):
         xh, yh = tuple(x.item() for x in torch.rand(2))
 
         actions = torch.randint(len(effects), (len(steps),))
-        change = False
+        nb_changes = 0
 
         for s, a in zip(steps, actions):
             if s:
                 frames.append(scene2tensor(xh, yh, scene, size=size))
 
-            g, dx, dy = effects[a]
-            if g:
+            grasp, dx, dy = effects[a]
+
+            if grasp:
                 for b in scene:
                     if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
                         x, y = b.x, b.y
@@ -310,7 +322,7 @@ def generate_episode(steps, size=64):
                         else:
                             xh += dx
                             yh += dy
-                            change = True
+                            nb_changes += 1
             else:
                 x, y = xh, yh
                 xh += dx
@@ -318,7 +330,7 @@ def generate_episode(steps, size=64):
                 if xh < 0 or xh > 1 or yh < 0 or yh > 1:
                     xh, yh = x, y
 
-        if change:
+        if nb_changes > len(steps) // 3:
             break
 
     return frames, actions
@@ -352,12 +364,21 @@ def create_data_and_processors(
         steps = [True] + [False] * (nb_steps + 1) + [True]
 
     train_input, train_actions = generate_episodes(nb_train_samples, steps)
-    train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage)
+    train_input, train_actions = train_input.to(device_storage), train_actions.to(
+        device_storage
+    )
     test_input, test_actions = generate_episodes(nb_test_samples, steps)
-    test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage)
+    test_input, test_actions = test_input.to(device_storage), test_actions.to(
+        device_storage
+    )
 
     encoder, quantizer, decoder = train_encoder(
-        train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device
+        train_input,
+        test_input,
+        lambda_entropy=1.0,
+        nb_epochs=nb_epochs,
+        logger=logger,
+        device=device,
     )
     encoder.train(False)
     quantizer.train(False)
@@ -371,7 +392,7 @@ def create_data_and_processors(
         seq = []
         p = pow2.to(device)
         for x in input.split(batch_size):
-            x=x.to(device)
+            x = x.to(device)
             z = encoder(x)
             ze_bool = (quantizer(z) >= 0).long()
             output = (