From ca5136e203f3dc9537aadb4071b786e34f1d7f39 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 3 Oct 2024 19:35:01 +0200
Subject: [PATCH] Update.

---
 difdis.py |  71 +++++++++----------
 grid.py   | 206 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 238 insertions(+), 39 deletions(-)
 create mode 100755 grid.py

diff --git a/difdis.py b/difdis.py
index 5ec408f..34fef17 100755
--- a/difdis.py
+++ b/difdis.py
@@ -412,7 +412,7 @@ class AutoEncoder(nn.Module):
     def forward(self, x):
         x = self.encoder(x)
         x = self.decoder(x)
-        return x
+        return x * 1e-3
 
 
 ######################################################################
@@ -488,28 +488,35 @@ def pb(e, desc):
     )
 
 
+def save_logit_image(logit_image, filename):
+    image = (
+        logit_image.softmax(dim=1)
+        * torch.arange(logit_image.size(1), device=logit_image.device)[
+            None, :, None, None
+        ]
+    ).sum(dim=1, keepdim=True)
+    image = image / 255
+
+    torchvision.utils.save_image(1 - image, filename, nrow=16, pad_value=0.8)
+
+
 for n_epoch in range(args.nb_epochs):
     acc_loss = 0
 
     for targets in pb(train_input.split(args.batch_size), "train"):
-        targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
+        targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2)
+        targets = (targets * 0.9 + 0.1 / targets.size(1)).log()
         input = torch.randn(targets.size(), device=targets.device) * 1e-3
 
-        # print(f"-----------------")
-
         loss = 0
 
         for n in range(nb_iterations):
+            input = input.log_softmax(dim=1)
             output = model(input).clamp(min=-10, max=10)
-            # assert not output.isnan().any()
 
             current_kl = kl(output, targets)
-            # assert not current_kl.isnan().any()
-
             nb_remain = nb_iterations - n
             tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain
-            # assert not tolerated_kl.isnan().any()
-
             with torch.no_grad():
                 a = input.new_full((input.size(0), 1, 1, 1), 0.0)
                 b = input.new_full((input.size(0), 1, 1, 1), 1.0)
@@ -517,28 +524,21 @@ for n_epoch in range(args.nb_epochs):
                 # KL(a * output + (1-a) * targets) = 0
                 # KL(b * output + (1-b) * targets) >> 0
 
-                for _ in range(10):
+                for _ in range(20):
                     c = (a + b) / 2
-                    kl_c = kl(c * output + (1 - c) * targets, targets)
-                    # print(f"{c.size()=} {output.size()=} {targets.size()=} {kl_c.size()=}")
-                    # print(f"{kl_c}")
-                    # assert kl_c.min() >= 0
-                    # assert not kl_c.isnan().any()
+                    kl_c = kl((1 - c) * output + c * targets, targets)
                     m = (kl_c >= tolerated_kl).long()
-                    a = m * a + (1 - m) * c
-                    b = m * c + (1 - m) * b
-                    # print(f"{a.size()=} {b.size()=} {m.size()=}")
+                    # m=1 -> kl > tolerated => a = c
+                    # m=0 -> kl < tolerated => b = c
+                    a = m * c + (1 - m) * a
+                    b = m * b + (1 - m) * c
 
                 c = (a + b) / 2
 
-                # print()
-                # print(tolerated_kl.flatten())
-                # print(kl_c.flatten())
-                # print(c.flatten())
+                # print((kl_c / (tolerated_kl+1e-6)).flatten())
 
-            input = c * output + (1 - c) * targets
+            input = (1 - c) * output + c * targets
             loss += kl(output, input).mean()
-            # assert not loss.isnan()
             input = input.detach()
 
         optimizer.zero_grad()
@@ -549,27 +549,20 @@ for n_epoch in range(args.nb_epochs):
 
     log_string(f"acc_loss {n_epoch} {acc_loss}")
 
+    save_logit_image(output, f"train_output_{n_epoch:04d}.png")
+
     ######################################################################
 
-    targets = test_input[:256]
-    targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0
-    input = torch.randn(targets.size(), device=targets.device) * 1e-3
     model.eval()
 
-    input = torch.randn(input.size(), device=input.device)
+    input = torch.randn((256, 256, 28, 28), device=targets.device) * 1e-3
+    input = input.log_softmax(dim=1)
+
     for _ in range(nb_iterations):
-        output = model(input)
+        input = input.log_softmax(dim=1)
+        output = model(input).clamp(min=-10, max=10)
         input = output.detach()
 
-    output = (
-        output.softmax(dim=1)
-        * torch.arange(output.size(1), device=output.device)[None, :, None, None]
-    ).sum(dim=1)
-    output = output[:, None, :, :] / 255
-
-    torchvision.utils.save_image(
-        1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8
-    )
-
+    save_logit_image(output, f"test_output_{n_epoch:04d}.png")
 
 ######################################################################
diff --git a/grid.py b/grid.py
new file mode 100755
index 0000000..fcf741b
--- /dev/null
+++ b/grid.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret
+
+# This code implements a simple system to manipulate formal
+# specifications of tokens on a grid.
+
+import math, re
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+######################################################################
+
+
+class FormalGrid:
+    def __init__(
+        self, grid_height=8, grid_width=8, nb_symbols=4, device=torch.device("cpu")
+    ):
+        self.grid_height = grid_height
+        self.grid_width = grid_width
+        self.nb_symbols = nb_symbols
+        self.nb_configs = (self.grid_height * self.grid_width) ** self.nb_symbols
+
+        self.row = torch.empty(
+            self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+        )
+        self.col = torch.empty(
+            self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device
+        )
+
+        i = torch.arange(self.nb_configs, device=device)
+
+        k = 1
+        for s in range(self.nb_symbols):
+            self.row[:, s] = (i // k) % self.grid_height
+            k *= self.grid_height
+            self.col[:, s] = (i // k) % self.grid_width
+            k *= self.grid_width
+
+        nb = 0
+        for s in range(self.nb_symbols):
+            nb += F.one_hot(
+                (self.row[:, s] * self.grid_width + self.col[:, s]).long(),
+                num_classes=self.grid_height * self.grid_width,
+            ).to(torch.uint8)
+        self.master_grid_set = nb.max(dim=1).values <= 1
+
+    def new_grid_set(self, constraints=None):
+        g = self.master_grid_set.clone()
+        if constraints:
+            self.add_constraints(g, constraints)
+        return g
+
+    ######################################################################
+
+    def constraint_to_fun(self, constraint):
+        g = [None]
+
+        def match(pattern):
+            r = re.search("^" + pattern + "$", constraint)
+            if r:
+                g[0] = (int(x) - 1 for x in r.groups())
+                return True
+            else:
+                return False
+
+        if match("([1-9]) top"):
+            (a,) = g[0]
+            return self.row[:, a] < self.grid_height // 4
+        elif match("([1-9]) bottom"):
+            (a,) = g[0]
+            return self.row[:, a] >= (self.grid_height * 3) // 4
+        elif match("([1-9]) left"):
+            (a,) = g[0]
+            return self.col[:, a] < self.grid_width // 4
+        elif match("([1-9]) right"):
+            (a,) = g[0]
+            return self.col[:, a] >= (self.grid_width * 3) // 4
+        elif match("([1-9]) next_to ([1-9])"):
+            a, b = g[0]
+            return (self.row[:, a] - self.row[:, b]).abs() + (
+                self.col[:, a] - self.col[:, b]
+            ).abs() <= 1
+        elif match("([1-9]) below_of ([1-9])"):
+            a, b = g[0]
+            return self.row[:, a] > self.row[:, b]
+        elif match("([1-9]) above ([1-9])"):
+            a, b = g[0]
+            return self.row[:, a] < self.row[:, b]
+        elif match("([1-9]) left_of ([1-9])"):
+            a, b = g[0]
+            return self.col[:, a] < self.col[:, b]
+        elif match("([1-9]) right_of ([1-9])"):
+            a, b = g[0]
+            return self.col[:, a] > self.col[:, b]
match("([1-9]) ([1-9]) diagonal"): + a, b = g[0] + return (self.col[:, a] - self.col[:, b]).abs() == ( + self.row[:, a] - self.row[:, b] + ).abs() + elif match("([1-9]) ([1-9]) vertical"): + a, b = g[0] + return self.col[:, a] == self.col[:, b] + elif match("([1-9]) ([1-9]) horizontal"): + a, b = g[0] + return self.row[:, a] == self.row[:, b] + + elif match("([1-9]) ([1-9]) ([1-9]) aligned"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) * ( + self.row[:, a] - self.row[:, c] + ) - (self.row[:, a] - self.row[:, b]) * ( + self.col[:, a] - self.col[:, c] + ) == 0 + + elif match("([1-9]) middle_of ([1-9]) ([1-9])"): + a, b, c = g[0] + return ( + grid_set + & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b]) + & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b]) + ) + + elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) ** 2 + ( + self.row[:, a] - self.row[:, b] + ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + ( + self.row[:, a] - self.row[:, c] + ) ** 2 + + elif match("([1-9]) ([1-9]) ([1-9]) right_angle"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) * ( + self.col[:, c] - self.col[:, b] + ) + (self.row[:, a] - self.row[:, b]) * ( + self.row[:, c] - self.row[:, b] + ) == 0 + + else: + raise ValueError(f"Unknown type of constraint {constraint}") + + ###################################################################### + + def check_reasonning(self, steps): + grid_set = grid.new_grid_set() + for step in steps: + if step[0] == "=>": + f = self.constraint_to_fun(step[1:]) + if (grid_set & torch.logical_not(f)).any(): + return False, step[1:] + else: + grid_set[...] = grid_set & self.constraint_to_fun(step) + + return True, None + + def add_constraints(self, grid_set, constraints): + for constraint in constraints: + grid_set[...] = grid_set & self.constraint_to_fun(constraint) + + ###################################################################### + + def views(self, grid_set): + g = torch.empty(self.grid_height, self.grid_width, dtype=torch.int64) + row, col = self.row[grid_set], self.col[grid_set] + i = torch.randperm(row.size(0)) + row, col = row[i], col[i] + for r, c in zip(row, col): + g.zero_() + for k in range(self.nb_symbols): + g[r[k], c[k]] = k + 1 + v = "" + for r in g: + v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n" + yield v + + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +grid = FormalGrid(device=device) + +grid_set = grid.new_grid_set( + [ + "2 top", + "4 right", + "1 left_of 2", + "2 left_of 3", + "1 2 4 right_angle", + "1 2 3 aligned", + # "3 2 diagonal", + "2 further_away_from 3 than 4", + ], +) + +print(f"There are {grid_set.long().sum().item()} configurations") + +for v in grid.views(grid_set): + print(v) -- 2.39.5