Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 1 Mar 2024 18:15:54 +0000 (19:15 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 1 Mar 2024 18:15:54 +0000 (19:15 +0100)
picocrafter.py
tiny_vae.py [new file with mode: 0755]

index 5bd6a48..23d93b2 100755 (executable)
@@ -23,8 +23,8 @@
 # iterations.
 #
 # The environment is a rectangular area with walls "#" dispatched
-# randomly. The agent "@" can perform five actions: move NESW or do
-# not move.
+# randomly. The agent "@" can perform five actions: move "NESW" or be
+# immobile "I".
 #
 # There are monsters "$" moving randomly. The agent gets hit by every
 # monster present in one of the 4 direct neighborhoods at the end of
@@ -39,8 +39,9 @@
 # "B", "C"). The keys and vault can only be used in sequence:
 # initially the agent can move only to free spaces, or to the "a", in
 # which case the key is removed from the environment and the agent now
-# carries it, and can move to free spaces or the "A". When it moves to
-# the "A", it gets a reward, loses the "a", the "A" is removed from
+# carries it, it appears in the inventory at the bottom of the frame,
+# and the agent can now move to free spaces or the "A". When it moves
+# to the "A", it gets a reward, loses the "a", the "A" is removed from
 # the environment, but the agent can now move to the "b", etc. Rewards
 # are 1 for "A" and "B" and 10 for "C".
 
@@ -244,7 +245,7 @@ class PicroCrafterEnvironment:
 
     def action2str(self, n):
         if n >= 0 and n < 5:
-            return "XNESW"[n]
+            return "INESW"[n]
         else:
             return "?"
 
diff --git a/tiny_vae.py b/tiny_vae.py
new file mode 100755 (executable)
index 0000000..bbdbf1a
--- /dev/null
@@ -0,0 +1,240 @@
+#!/usr/bin/env python
+
+# @XREMOTE_HOST: elk.fleuret.org
+# @XREMOTE_EXEC: python
+# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
+# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
+# @XREMOTE_GET: *.png
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+import sys, os, argparse, time, math, itertools
+
+import torch, torchvision
+
+from torch import optim, nn
+from torch.nn import functional as F
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+######################################################################
+
+parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--batch_size", type=int, default=100)
+
+parser.add_argument("--data_dir", type=str, default="./data/")
+
+parser.add_argument("--log_filename", type=str, default="train.log")
+
+parser.add_argument("--latent_dim", type=int, default=32)
+
+parser.add_argument("--nb_channels", type=int, default=128)
+
+parser.add_argument("--no_dkl", action="store_true")
+
+args = parser.parse_args()
+
+log_file = open(args.log_filename, "w")
+
+######################################################################
+
+
+def log_string(s):
+    t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime())
+
+    if log_file is not None:
+        log_file.write(t + s + "\n")
+        log_file.flush()
+
+    print(t + s)
+    sys.stdout.flush()
+
+
+######################################################################
+
+
+def sample_gaussian(mu, log_var):
+    std = log_var.mul(0.5).exp()
+    return torch.randn(mu.size(), device=mu.device) * std + mu
+
+
+def log_p_gaussian(x, mu, log_var):
+    var = log_var.exp()
+    return (
+        (-0.5 * ((x - mu).pow(2) / var) - 0.5 * log_var - 0.5 * math.log(2 * math.pi))
+        .flatten(1)
+        .sum(1)
+    )
+
+
+def dkl_gaussians(mu_a, log_var_a, mu_b, log_var_b):
+    mu_a, log_var_a = mu_a.flatten(1), log_var_a.flatten(1)
+    mu_b, log_var_b = mu_b.flatten(1), log_var_b.flatten(1)
+    var_a = log_var_a.exp()
+    var_b = log_var_b.exp()
+    return 0.5 * (
+        log_var_b - log_var_a - 1 + (mu_a - mu_b).pow(2) / var_b + var_a / var_b
+    ).sum(1)
+
+
+######################################################################
+
+
+class LatentGivenImageNet(nn.Module):
+    def __init__(self, nb_channels, latent_dim):
+        super().__init__()
+
+        self.model = nn.Sequential(
+            nn.Conv2d(1, nb_channels, kernel_size=1),  # to 28x28
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 24x24
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=5),  # to 20x20
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2),  # to 9x9
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2),  # to 4x4
+            nn.ReLU(inplace=True),
+            nn.Conv2d(nb_channels, 2 * latent_dim, kernel_size=4),
+        )
+
+    def forward(self, x):
+        output = self.model(x).view(x.size(0), 2, -1)
+        mu, log_var = output[:, 0], output[:, 1]
+        return mu, log_var
+
+
+class ImageGivenLatentNet(nn.Module):
+    def __init__(self, nb_channels, latent_dim):
+        super().__init__()
+
+        self.model = nn.Sequential(
+            nn.ConvTranspose2d(latent_dim, nb_channels, kernel_size=4),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=3, stride=2
+            ),  # from 4x4
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                nb_channels, nb_channels, kernel_size=4, stride=2
+            ),  # from 9x9
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5),  # from 20x20
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(nb_channels, 2, kernel_size=5),  # from 24x24
+        )
+
+    def forward(self, z):
+        output = self.model(z.view(z.size(0), -1, 1, 1))
+        mu, log_var = output[:, 0:1], output[:, 1:2]
+        return mu, log_var
+
+
+######################################################################
+
+data_dir = os.path.join(args.data_dir, "mnist")
+
+train_set = torchvision.datasets.MNIST(data_dir, train=True, download=True)
+train_input = train_set.data.view(-1, 1, 28, 28).float()
+
+test_set = torchvision.datasets.MNIST(data_dir, train=False, download=True)
+test_input = test_set.data.view(-1, 1, 28, 28).float()
+
+######################################################################
+
+model_q_Z_given_x = LatentGivenImageNet(
+    nb_channels=args.nb_channels, latent_dim=args.latent_dim
+)
+
+model_p_X_given_z = ImageGivenLatentNet(
+    nb_channels=args.nb_channels, latent_dim=args.latent_dim
+)
+
+optimizer = optim.Adam(
+    itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
+    lr=4e-4,
+)
+
+model_p_X_given_z.to(device)
+model_q_Z_given_x.to(device)
+
+######################################################################
+
+train_input, test_input = train_input.to(device), test_input.to(device)
+
+train_mu, train_std = train_input.mean(), train_input.std()
+train_input.sub_(train_mu).div_(train_std)
+test_input.sub_(train_mu).div_(train_std)
+
+######################################################################
+
+mu_p_Z = train_input.new_zeros(1, args.latent_dim)
+log_var_p_Z = mu_p_Z
+
+for epoch in range(args.nb_epochs):
+    acc_loss = 0
+
+    for x in train_input.split(args.batch_size):
+        mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
+        z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
+        mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+
+        if args.no_dkl:
+            log_q_z_given_x = log_p_gaussian(z, mu_q_Z_given_x, log_var_q_Z_given_x)
+            log_p_x_z = log_p_gaussian(
+                x, mu_p_X_given_z, log_var_p_X_given_z
+            ) + log_p_gaussian(z, mu_p_Z, log_var_p_Z)
+            loss = -(log_p_x_z - log_q_z_given_x).mean()
+        else:
+            log_p_x_given_z = log_p_gaussian(x, mu_p_X_given_z, log_var_p_X_given_z)
+            dkl_q_Z_given_x_from_p_Z = dkl_gaussians(
+                mu_q_Z_given_x, log_var_q_Z_given_x, mu_p_Z, log_var_p_Z
+            )
+            loss = (-log_p_x_given_z + dkl_q_Z_given_x_from_p_Z).mean()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        acc_loss += loss.item() * x.size(0)
+
+    log_string(f"acc_loss {epoch} {acc_loss/train_input.size(0)}")
+
+######################################################################
+
+
+def save_image(x, filename):
+    x = x * train_std + train_mu
+    x = x.clamp(min=0, max=255) / 255
+    torchvision.utils.save_image(1 - x, filename, nrow=16, pad_value=0.8)
+
+
+# Save a bunch of test images
+
+x = test_input[:256]
+save_image(x, "input.png")
+
+# Save the same images after encoding / decoding
+
+mu_q_Z_given_x, log_var_q_Z_given_x = model_q_Z_given_x(x)
+z = sample_gaussian(mu_q_Z_given_x, log_var_q_Z_given_x)
+mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
+save_image(x, "output.png")
+
+# Generate a bunch of images
+
+z = sample_gaussian(mu_p_Z.expand(x.size(0), -1), log_var_p_Z.expand(x.size(0), -1))
+mu_p_X_given_z, log_var_p_X_given_z = model_p_X_given_z(z)
+x = sample_gaussian(mu_p_X_given_z, log_var_p_X_given_z)
+save_image(x, "synth.png")
+
+######################################################################