def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
- return x
+ return x * 1e-3
######################################################################
)
+def save_logit_image(logit_image, filename):
+ image = (
+ logit_image.softmax(dim=1)
+ * torch.arange(logit_image.size(1), device=logit_image.device)[
+ None, :, None, None
+ ]
+ ).sum(dim=1, keepdim=True)
+ image = image / 255
+
+ torchvision.utils.save_image(1 - image, filename, nrow=16, pad_value=0.8)
+
+
for n_epoch in range(args.nb_epochs):
acc_loss = 0
for targets in pb(train_input.split(args.batch_size), "train"):
- targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
+ targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2)
+ targets = (targets * 0.9 + 0.1 / targets.size(1)).log()
input = torch.randn(targets.size(), device=targets.device) * 1e-3
- # print(f"-----------------")
-
loss = 0
for n in range(nb_iterations):
+ input = input.log_softmax(dim=1)
output = model(input).clamp(min=-10, max=10)
- # assert not output.isnan().any()
current_kl = kl(output, targets)
- # assert not current_kl.isnan().any()
-
nb_remain = nb_iterations - n
tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain
- # assert not tolerated_kl.isnan().any()
-
with torch.no_grad():
a = input.new_full((input.size(0), 1, 1, 1), 0.0)
b = input.new_full((input.size(0), 1, 1, 1), 1.0)
# KL(a * output + (1-a) * targets) = 0
# KL(b * output + (1-b) * targets) >> 0
- for _ in range(10):
+ for _ in range(20):
c = (a + b) / 2
- kl_c = kl(c * output + (1 - c) * targets, targets)
- # print(f"{c.size()=} {output.size()=} {targets.size()=} {kl_c.size()=}")
- # print(f"{kl_c}")
- # assert kl_c.min() >= 0
- # assert not kl_c.isnan().any()
+ kl_c = kl((1 - c) * output + c * targets, targets)
m = (kl_c >= tolerated_kl).long()
- a = m * a + (1 - m) * c
- b = m * c + (1 - m) * b
- # print(f"{a.size()=} {b.size()=} {m.size()=}")
+ # m=1 -> kl > tolerated => a = c
+ # m=0 -> kl < tolerated => b = c
+ a = m * c + (1 - m) * a
+ b = m * b + (1 - m) * c
c = (a + b) / 2
- # print()
- # print(tolerated_kl.flatten())
- # print(kl_c.flatten())
- # print(c.flatten())
+ # print((kl_c / (tolerated_kl+1e-6)).flatten())
- input = c * output + (1 - c) * targets
+ input = (1 - c) * output + c * targets
loss += kl(output, input).mean()
- # assert not loss.isnan()
input = input.detach()
optimizer.zero_grad()
log_string(f"acc_loss {n_epoch} {acc_loss}")
+ save_logit_image(output, f"train_output_{n_epoch:04d}.png")
+
######################################################################
- targets = test_input[:256]
- targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
- input = torch.randn(targets.size(), device=targets.device) * 1e-3
model.eval()
- input = torch.randn(input.size(), device=input.device)
+ input = torch.randn((256, 256, 28, 28), device=targets.device) * 1e-3
+ input = input.log_softmax(dim=1)
+
for _ in range(nb_iterations):
- output = model(input)
+ input = input.log_softmax(dim=1)
+ output = model(input).clamp(min=-10, max=10)
input = output.detach()
- output = (
- output.softmax(dim=1)
- * torch.arange(output.size(1), device=output.device)[None, :, None, None]
- ).sum(dim=1)
- output = output[:, None, :, :] / 255
-
- torchvision.utils.save_image(
- 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8
- )
-
+ save_logit_image(output, f"test_output_{n_epoch:04d}.png")
######################################################################
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+# This code implement a simple system to manipulate formal
+# specifications of tokens on a grid.
+
+import math, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class FormalGrid:
+ def __init__(
+ self, grid_height=8, grid_width=8, nb_symbols=4, device=torch.device("cpu")
+ ):
+ self.grid_height = grid_height
+ self.grid_width = grid_width
+ self.nb_symbols = nb_symbols
+ self.nb_configs = (self.grid_height * self.grid_width) ** self.nb_symbols
+
+ self.row = torch.empty(
+ self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+ )
+ self.col = torch.empty(
+ self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+ )
+
+ i = torch.arange(self.nb_configs, device=device)
+
+ k = 1
+ for s in range(self.nb_symbols):
+ self.row[:, s] = (i // k) % self.grid_height
+ k *= self.grid_height
+ self.col[:, s] = (i // k) % self.grid_width
+ k *= self.grid_width
+
+ nb = 0
+ for s in range(self.nb_symbols):
+ nb += F.one_hot(
+ (self.row[:, s] * self.grid_width + self.col[:, s]).long(),
+ num_classes=self.grid_height * self.grid_width,
+ ).to(torch.uint8)
+ self.master_grid_set = nb.max(dim=1).values <= 1
+
+ def new_grid_set(self, constraints=None):
+ g = self.master_grid_set.clone()
+ if constraints:
+ self.add_constraints(g, constraints)
+ return g
+
+ ######################################################################
+
+ def constraint_to_fun(self, constraint):
+ g = [None]
+
+ def match(pattern):
+ r = re.search("^" + pattern + "$", constraint)
+ if r:
+ g[0] = (int(x) - 1 for x in r.groups())
+ return True
+ else:
+ return False
+
+ if match("([1-9]) top"):
+ (a,) = g[0]
+ return self.row[:, a] < self.grid_height // 4
+ elif match("([1-9]) bottom"):
+ (a,) = g[0]
+ return self.row[:, a] >= (self.grid_height * 3) // 4
+ elif match("([1-9]) left"):
+ (a,) = g[0]
+ return self.col[:, a] < self.grid_width // 4
+ elif match("([1-9]) right"):
+ (a,) = g[0]
+ return self.col[:, a] >= (self.grid_width * 3) // 4
+ elif match("([1-9]) next_to ([1-9])"):
+ a, b = g[0]
+ return (self.row[:, a] - self.row[:, b]).abs() + (
+ self.col[:, a] - self.col[:, b]
+ ).abs() <= 1
+ elif match("([1-9]) below_of ([1-9])"):
+ a, b = g[0]
+ return self.row[:, a] > self.row[:, b]
+ elif match("([1-9]) above ([1-9])"):
+ a, b = g[0]
+ return self.row[:, a] < self.row[:, b]
+ elif match("([1-9]) left_of ([1-9])"):
+ a, b = g[0]
+ return self.col[:, a] < self.col[:, b]
+ elif match("([1-9]) right_of ([1-9])"):
+ a, b = g[0]
+ return self.col[:, a] > self.col[:, b]
+ elif match("([1-9]) ([1-9]) diagonal"):
+ a, b = g[0]
+ return (self.col[:, a] - self.col[:, b]).abs() == (
+ self.row[:, a] - self.row[:, b]
+ ).abs()
+ elif match("([1-9]) ([1-9]) vertical"):
+ a, b = g[0]
+ return self.col[:, a] == self.col[:, b]
+ elif match("([1-9]) ([1-9]) horizontal"):
+ a, b = g[0]
+ return self.row[:, a] == self.row[:, b]
+
+ elif match("([1-9]) ([1-9]) ([1-9]) aligned"):
+ a, b, c = g[0]
+ return (self.col[:, a] - self.col[:, b]) * (
+ self.row[:, a] - self.row[:, c]
+ ) - (self.row[:, a] - self.row[:, b]) * (
+ self.col[:, a] - self.col[:, c]
+ ) == 0
+
+ elif match("([1-9]) middle_of ([1-9]) ([1-9])"):
+ a, b, c = g[0]
+ return (
+ grid_set
+ & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b])
+ & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b])
+ )
+
+ elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"):
+ a, b, c = g[0]
+ return (self.col[:, a] - self.col[:, b]) ** 2 + (
+ self.row[:, a] - self.row[:, b]
+ ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + (
+ self.row[:, a] - self.row[:, c]
+ ) ** 2
+
+ elif match("([1-9]) ([1-9]) ([1-9]) right_angle"):
+ a, b, c = g[0]
+ return (self.col[:, a] - self.col[:, b]) * (
+ self.col[:, c] - self.col[:, b]
+ ) + (self.row[:, a] - self.row[:, b]) * (
+ self.row[:, c] - self.row[:, b]
+ ) == 0
+
+ else:
+ raise ValueError(f"Unknown type of constraint {constraint}")
+
+ ######################################################################
+
+ def check_reasonning(self, steps):
+ grid_set = grid.new_grid_set()
+ for step in steps:
+ if step[0] == "=>":
+ f = self.constraint_to_fun(step[1:])
+ if (grid_set & torch.logical_not(f)).any():
+ return False, step[1:]
+ else:
+ grid_set[...] = grid_set & self.constraint_to_fun(step)
+
+ return True, None
+
+ def add_constraints(self, grid_set, constraints):
+ for constraint in constraints:
+ grid_set[...] = grid_set & self.constraint_to_fun(constraint)
+
+ ######################################################################
+
+ def views(self, grid_set):
+ g = torch.empty(self.grid_height, self.grid_width, dtype=torch.int64)
+ row, col = self.row[grid_set], self.col[grid_set]
+ i = torch.randperm(row.size(0))
+ row, col = row[i], col[i]
+ for r, c in zip(row, col):
+ g.zero_()
+ for k in range(self.nb_symbols):
+ g[r[k], c[k]] = k + 1
+ v = ""
+ for r in g:
+ v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n"
+ yield v
+
+
+######################################################################
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+grid = FormalGrid(device=device)
+
+grid_set = grid.new_grid_set(
+ [
+ "2 top",
+ "4 right",
+ "1 left_of 2",
+ "2 left_of 3",
+ "1 2 4 right_angle",
+ "1 2 3 aligned",
+ # "3 2 diagonal",
+ "2 further_away_from 3 than 4",
+ ],
+)
+
+print(f"There are {grid_set.long().sum().item()} configurations")
+
+for v in grid.views(grid_set):
+ print(v)