Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 3 Oct 2024 17:35:01 +0000 (19:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 3 Oct 2024 17:35:01 +0000 (19:35 +0200)
difdis.py
grid.py [new file with mode: 0755]

index 5ec408f..34fef17 100755 (executable)
--- a/difdis.py
+++ b/difdis.py
@@ -412,7 +412,7 @@ class AutoEncoder(nn.Module):
     def forward(self, x):
         x = self.encoder(x)
         x = self.decoder(x)
-        return x
+        return x * 1e-3
 
 
 ######################################################################
@@ -488,28 +488,35 @@ def pb(e, desc):
     )
 
 
+def save_logit_image(logit_image, filename):
+    image = (
+        logit_image.softmax(dim=1)
+        * torch.arange(logit_image.size(1), device=logit_image.device)[
+            None, :, None, None
+        ]
+    ).sum(dim=1, keepdim=True)
+    image = image / 255
+
+    torchvision.utils.save_image(1 - image, filename, nrow=16, pad_value=0.8)
+
+
 for n_epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for targets in pb(train_input.split(args.batch_size), "train"):
-        targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
+        targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2)
+        targets = (targets * 0.9 + 0.1 / targets.size(1)).log()
         input = torch.randn(targets.size(), device=targets.device) * 1e-3
 
-        # print(f"-----------------")
-
         loss = 0
         for n in range(nb_iterations):
+            input = input.log_softmax(dim=1)
             output = model(input).clamp(min=-10, max=10)
-            # assert not output.isnan().any()
             current_kl = kl(output, targets)
 
-            # assert not current_kl.isnan().any()
-
             nb_remain = nb_iterations - n
             tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain
 
-            # assert not tolerated_kl.isnan().any()
-
             with torch.no_grad():
                 a = input.new_full((input.size(0), 1, 1, 1), 0.0)
                 b = input.new_full((input.size(0), 1, 1, 1), 1.0)
@@ -517,28 +524,21 @@ for n_epoch in range(args.nb_epochs):
                 # KL(a * output + (1-a) * targets) = 0
                 # KL(b * output + (1-b) * targets) >> 0
 
-                for _ in range(10):
+                for _ in range(20):
                     c = (a + b) / 2
-                    kl_c = kl(c * output + (1 - c) * targets, targets)
-                    # print(f"{c.size()=} {output.size()=} {targets.size()=} {kl_c.size()=}")
-                    # print(f"{kl_c}")
-                    # assert kl_c.min() >= 0
-                    # assert not kl_c.isnan().any()
+                    kl_c = kl((1 - c) * output + c * targets, targets)
                     m = (kl_c >= tolerated_kl).long()
-                    a = m * a + (1 - m) * c
-                    b = m * c + (1 - m) * b
-                    # print(f"{a.size()=} {b.size()=} {m.size()=}")
+                    # m=1 -> kl > tolerated => a = c
+                    # m=0 -> kl < tolerated => b = c
+                    a = m * c + (1 - m) * a
+                    b = m * b + (1 - m) * c
 
                 c = (a + b) / 2
 
-            # print()
-            # print(tolerated_kl.flatten())
-            # print(kl_c.flatten())
-            # print(c.flatten())
+            # print((kl_c / (tolerated_kl+1e-6)).flatten())
 
-            input = c * output + (1 - c) * targets
+            input = (1 - c) * output + c * targets
             loss += kl(output, input).mean()
-            # assert not loss.isnan()
             input = input.detach()
 
         optimizer.zero_grad()
@@ -549,27 +549,20 @@ for n_epoch in range(args.nb_epochs):
 
     log_string(f"acc_loss {n_epoch} {acc_loss}")
 
+    save_logit_image(output, f"train_output_{n_epoch:04d}.png")
+
     ######################################################################
 
-    targets = test_input[:256]
-    targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
-    input = torch.randn(targets.size(), device=targets.device) * 1e-3
     model.eval()
 
-    input = torch.randn(input.size(), device=input.device)
+    input = torch.randn((256, 256, 28, 28), device=targets.device) * 1e-3
+    input = input.log_softmax(dim=1)
+
     for _ in range(nb_iterations):
-        output = model(input)
+        input = input.log_softmax(dim=1)
+        output = model(input).clamp(min=-10, max=10)
         input = output.detach()
 
-    output = (
-        output.softmax(dim=1)
-        * torch.arange(output.size(1), device=output.device)[None, :, None, None]
-    ).sum(dim=1)
-    output = output[:, None, :, :] / 255
-
-    torchvision.utils.save_image(
-        1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8
-    )
-
+    save_logit_image(output, f"test_output_{n_epoch:04d}.png")
 
 ######################################################################
diff --git a/grid.py b/grid.py
new file mode 100755 (executable)
index 0000000..fcf741b
--- /dev/null
+++ b/grid.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# This code implement a simple system to manipulate formal
+# specifications of tokens on a grid.
+
+import math, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class FormalGrid:
+    def __init__(
+        self, grid_height=8, grid_width=8, nb_symbols=4, device=torch.device("cpu")
+    ):
+        self.grid_height = grid_height
+        self.grid_width = grid_width
+        self.nb_symbols = nb_symbols
+        self.nb_configs = (self.grid_height * self.grid_width) ** self.nb_symbols
+
+        self.row = torch.empty(
+            self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+        )
+        self.col = torch.empty(
+            self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+        )
+
+        i = torch.arange(self.nb_configs, device=device)
+
+        k = 1
+        for s in range(self.nb_symbols):
+            self.row[:, s] = (i // k) % self.grid_height
+            k *= self.grid_height
+            self.col[:, s] = (i // k) % self.grid_width
+            k *= self.grid_width
+
+        nb = 0
+        for s in range(self.nb_symbols):
+            nb += F.one_hot(
+                (self.row[:, s] * self.grid_width + self.col[:, s]).long(),
+                num_classes=self.grid_height * self.grid_width,
+            ).to(torch.uint8)
+        self.master_grid_set = nb.max(dim=1).values <= 1
+
+    def new_grid_set(self, constraints=None):
+        g = self.master_grid_set.clone()
+        if constraints:
+            self.add_constraints(g, constraints)
+        return g
+
+    ######################################################################
+
+    def constraint_to_fun(self, constraint):
+        g = [None]
+
+        def match(pattern):
+            r = re.search("^" + pattern + "$", constraint)
+            if r:
+                g[0] = (int(x) - 1 for x in r.groups())
+                return True
+            else:
+                return False
+
+        if match("([1-9]) top"):
+            (a,) = g[0]
+            return self.row[:, a] < self.grid_height // 4
+        elif match("([1-9]) bottom"):
+            (a,) = g[0]
+            return self.row[:, a] >= (self.grid_height * 3) // 4
+        elif match("([1-9]) left"):
+            (a,) = g[0]
+            return self.col[:, a] < self.grid_width // 4
+        elif match("([1-9]) right"):
+            (a,) = g[0]
+            return self.col[:, a] >= (self.grid_width * 3) // 4
+        elif match("([1-9]) next_to ([1-9])"):
+            a, b = g[0]
+            return (self.row[:, a] - self.row[:, b]).abs() + (
+                self.col[:, a] - self.col[:, b]
+            ).abs() <= 1
+        elif match("([1-9]) below_of ([1-9])"):
+            a, b = g[0]
+            return self.row[:, a] > self.row[:, b]
+        elif match("([1-9]) above ([1-9])"):
+            a, b = g[0]
+            return self.row[:, a] < self.row[:, b]
+        elif match("([1-9]) left_of ([1-9])"):
+            a, b = g[0]
+            return self.col[:, a] < self.col[:, b]
+        elif match("([1-9]) right_of ([1-9])"):
+            a, b = g[0]
+            return self.col[:, a] > self.col[:, b]
+        elif match("([1-9]) ([1-9]) diagonal"):
+            a, b = g[0]
+            return (self.col[:, a] - self.col[:, b]).abs() == (
+                self.row[:, a] - self.row[:, b]
+            ).abs()
+        elif match("([1-9]) ([1-9]) vertical"):
+            a, b = g[0]
+            return self.col[:, a] == self.col[:, b]
+        elif match("([1-9]) ([1-9]) horizontal"):
+            a, b = g[0]
+            return self.row[:, a] == self.row[:, b]
+
+        elif match("([1-9]) ([1-9]) ([1-9]) aligned"):
+            a, b, c = g[0]
+            return (self.col[:, a] - self.col[:, b]) * (
+                self.row[:, a] - self.row[:, c]
+            ) - (self.row[:, a] - self.row[:, b]) * (
+                self.col[:, a] - self.col[:, c]
+            ) == 0
+
+        elif match("([1-9]) middle_of ([1-9]) ([1-9])"):
+            a, b, c = g[0]
+            return (
+                grid_set
+                & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b])
+                & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b])
+            )
+
+        elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"):
+            a, b, c = g[0]
+            return (self.col[:, a] - self.col[:, b]) ** 2 + (
+                self.row[:, a] - self.row[:, b]
+            ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + (
+                self.row[:, a] - self.row[:, c]
+            ) ** 2
+
+        elif match("([1-9]) ([1-9]) ([1-9]) right_angle"):
+            a, b, c = g[0]
+            return (self.col[:, a] - self.col[:, b]) * (
+                self.col[:, c] - self.col[:, b]
+            ) + (self.row[:, a] - self.row[:, b]) * (
+                self.row[:, c] - self.row[:, b]
+            ) == 0
+
+        else:
+            raise ValueError(f"Unknown type of constraint {constraint}")
+
+    ######################################################################
+
+    def check_reasonning(self, steps):
+        grid_set = grid.new_grid_set()
+        for step in steps:
+            if step[0] == "=>":
+                f = self.constraint_to_fun(step[1:])
+                if (grid_set & torch.logical_not(f)).any():
+                    return False, step[1:]
+            else:
+                grid_set[...] = grid_set & self.constraint_to_fun(step)
+
+        return True, None
+
+    def add_constraints(self, grid_set, constraints):
+        for constraint in constraints:
+            grid_set[...] = grid_set & self.constraint_to_fun(constraint)
+
+    ######################################################################
+
+    def views(self, grid_set):
+        g = torch.empty(self.grid_height, self.grid_width, dtype=torch.int64)
+        row, col = self.row[grid_set], self.col[grid_set]
+        i = torch.randperm(row.size(0))
+        row, col = row[i], col[i]
+        for r, c in zip(row, col):
+            g.zero_()
+            for k in range(self.nb_symbols):
+                g[r[k], c[k]] = k + 1
+            v = ""
+            for r in g:
+                v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n"
+            yield v
+
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+grid = FormalGrid(device=device)
+
+grid_set = grid.new_grid_set(
+    [
+        "2 top",
+        "4 right",
+        "1 left_of 2",
+        "2 left_of 3",
+        "1 2 4 right_angle",
+        "1 2 3 aligned",
+        # "3 2 diagonal",
+        "2 further_away_from 3 than 4",
+    ],
+)
+
+print(f"There are {grid_set.long().sum().item()} configurations")
+
+for v in grid.views(grid_set):
+    print(v)