Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 14 Jul 2023 16:08:35 +0000 (18:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 14 Jul 2023 16:08:35 +0000 (18:08 +0200)
world.py

index b077987..61a07e9 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -146,7 +146,7 @@ def train_encoder(
 
         for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
             z = encoder(input)
-            zq = z if k < 1 else quantizer(z)
+            zq = z if k < 2 else quantizer(z)
             output = decoder(zq)
 
             output = output.reshape(
@@ -334,8 +334,6 @@ def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10):
     train_input = generate_episodes(nb_train_samples, steps)
     test_input = generate_episodes(nb_test_samples, steps)
 
-    print(f"{train_input.size()=} {test_input.size()=}")
-
     encoder, quantizer, decoder = train_encoder(
         train_input, test_input, nb_epochs=nb_epochs
     )
@@ -347,26 +345,40 @@ def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10):
     pow2 = (2 ** torch.arange(z.size(1), device=z.device))[None, None, :]
     z_h, z_w = z.size(2), z.size(3)
 
-    def frame2seq(x):
-        z = encoder(x)
-        ze_bool = (quantizer(z) >= 0).long()
-        seq = (
-            ze_bool.permute(0, 2, 3, 1).reshape(ze_bool.size(0), -1, ze_bool.size(1))
-            * pow2
-        ).sum(-1)
-        return seq
-
-    def seq2frame(seq, T=1e-2):
-        zd_bool = (seq[:, :, None] // pow2) % 2
-        zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
-        logits = decoder(zd_bool * 2.0 - 1.0)
-        logits = logits.reshape(
-            logits.size(0), -1, 3, logits.size(2), logits.size(3)
-        ).permute(0, 2, 3, 4, 1)
-        results = torch.distributions.categorical.Categorical(
-            logits=logits / T
-        ).sample()
-        return results
+    def frame2seq(input, batch_size=25):
+        seq = []
+
+        for x in input.split(batch_size):
+            z = encoder(x)
+            ze_bool = (quantizer(z) >= 0).long()
+            output = (
+                ze_bool.permute(0, 2, 3, 1).reshape(
+                    ze_bool.size(0), -1, ze_bool.size(1)
+                )
+                * pow2
+            ).sum(-1)
+
+            seq.append(output)
+
+        return torch.cat(seq, dim=0)
+
+    def seq2frame(input, batch_size=25, T=1e-2):
+        frames = []
+
+        for seq in input.split(batch_size):
+            zd_bool = (seq[:, :, None] // pow2) % 2
+            zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2)
+            logits = decoder(zd_bool * 2.0 - 1.0)
+            logits = logits.reshape(
+                logits.size(0), -1, 3, logits.size(2), logits.size(3)
+            ).permute(0, 2, 3, 4, 1)
+            output = torch.distributions.categorical.Categorical(
+                logits=logits / T
+            ).sample()
+
+            frames.append(output)
+
+        return torch.cat(frames, dim=0)
 
     return train_input, test_input, frame2seq, seq2frame
 
@@ -375,7 +387,10 @@ def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10):
 
 if __name__ == "__main__":
     train_input, test_input, frame2seq, seq2frame = create_data_and_processors(
-        10000, 1000
+        # 10000, 1000,
+        100,
+        100,
+        nb_epochs=2,
     )
 
     input = test_input[:64]