From 3afcea624963ad2d381c19a7d54bb26e218c5bce Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 13 Jun 2024 20:05:29 +0200 Subject: [PATCH 01/16] Update. --- redshift.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/redshift.py b/redshift.py index b3507ed..2ed1e52 100755 --- a/redshift.py +++ b/redshift.py @@ -9,8 +9,10 @@ from torch.nn import functional as F torch.set_default_dtype(torch.float64) +nb_hidden = 5 +hidden_dim = 100 + res = 256 -nh = 100 input = torch.cat( [ @@ -28,11 +30,10 @@ class Angles(nn.Module): for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]: for s in [1.0, 10.0]: - layers = [nn.Linear(2, nh), activation()] - nb_hidden = 4 - for k in range(nb_hidden): - layers += [nn.Linear(nh, nh), activation()] - layers += [nn.Linear(nh, 2)] + layers = [nn.Linear(2, hidden_dim), activation()] + for k in range(nb_hidden - 1): + layers += [nn.Linear(hidden_dim, hidden_dim), activation()] + layers += [nn.Linear(hidden_dim, 2)] model = nn.Sequential(*layers) with torch.no_grad(): -- 2.39.5 From f730f5fc1003f74ae7afea3451f17ad8925bd909 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 1 Oct 2024 07:12:59 +0200 Subject: [PATCH 02/16] Update. --- tinygen.py | 520 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 520 insertions(+) create mode 100755 tinygen.py diff --git a/tinygen.py b/tinygen.py new file mode 100755 index 0000000..66c005c --- /dev/null +++ b/tinygen.py @@ -0,0 +1,520 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_EXEC: python +# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate +# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data +# @XREMOTE_GET: *.png + +import sys, argparse, time, math + +import torch, torchvision + +from torch import optim, nn +from torch.nn import functional as F +from tqdm import tqdm + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### + +parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") + +parser.add_argument("--nb_epochs", type=int, default=250) + +parser.add_argument("--batch_size", type=int, default=100) + +parser.add_argument("--data_dir", type=str, default="./data/") + +parser.add_argument("--log_filename", type=str, default="train.log") + +parser.add_argument("--embedding_dim", type=int, default=64) + +parser.add_argument("--nb_channels", type=int, default=64) + +args = parser.parse_args() + +log_file = open(args.log_filename, "w") + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +###################################################################### + + +class VaswaniPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + def forward(self, x): + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 # works with float, weird + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) + y = x + pe + return y + + +###################################################################### + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return x + self.f(x) + + +###################################################################### + + +def vanilla_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + +###################################################################### + + +class MHAttention(nn.Module): + def __init__( + self, + dim_model, + dim_qk, + dim_v, + nb_heads=1, + attention=vanilla_attention, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.attention = attention + self.attention_dropout = attention_dropout + self.w_q = randw(nb_heads, dim_qk, dim_model) + self.w_k = randw(nb_heads, dim_qk, dim_model) + self.w_v = randw(nb_heads, dim_v, dim_model) + self.w_o = randw(nb_heads, dim_v, dim_model) + + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q + + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k) + v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v) + y = self.attention(q, k, v) + y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) + + return y + + +###################################################################### + + +class AttentionAE(nn.Module): + def __init__( + self, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + nn.Linear(2, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=dropout, + ), + ), + WithResidual( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=1) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1) + x = self.embedding(x) + x = self.positional_encoding(x) + x = self.trunk(x) + x = self.readout(x).reshape(-1, 1, 28, 28) + return x + + +###################################################################### + + +class WithMaskedResidual(nn.Module): + def __init__(self, masker, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.masker = masker + self.mask = None + + def forward(self, x): + if self.mask is None: + self.mask = self.masker(x) + return self.mask * x + self.f(x) + + +###################################################################### + + +class FunctionalAttentionAE(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + nb_work_tokens=100, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.nb_work_tokens = nb_work_tokens + + self.embedding = nn.Sequential( + nn.Embedding(2 * vocabulary_size, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + def no_peek_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + n = self.nb_work_tokens + s = (q.size(2) - n) // 2 + a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf") + a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf") + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + def masker(x): + m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens + return m[None, :, None] + + for b in range(nb_blocks): + trunk_blocks += [ + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=no_peek_attention, + attention_dropout=dropout, + ), + ), + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = self.embedding(x) + x = F.pad(x, (0, 0, self.nb_work_tokens, 0)) + x = self.positional_encoding(x) + x = self.trunk(x) + x = F.pad(x, (0, 0, -self.nb_work_tokens, 0)) + x = self.readout(x) + return x + + +###################################################################### + + +class FullAveragePooling(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1) + return x + + +class ResNetBlock(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.conv1 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + self.conv2 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + def forward(self, x): + y = F.relu(self.conv1(x)) + y = F.relu(x + self.conv2(y)) + return y + + +###################################################################### + + +class ResAutoEncoder(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.encoder = nn.Conv2d( + 2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + self.core = nn.Sequential( + *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)] + ) + self.decoder = nn.Conv2d( + nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + + +class AutoEncoder(nn.Module): + def __init__(self, nb_channels, embedding_dim): + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, nb_channels, kernel_size=5), # to 24x24 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 20x20 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2), # to 9x9 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2), # to 4x4 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, embedding_dim, kernel_size=4), + ) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=3, stride=2 + ), # from 4x4 + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=4, stride=2 + ), # from 9x9 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5), # from 20x20 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, 1, kernel_size=5), # from 24x24 + ) + + def encode(self, x): + return self.encoder(x).view(x.size(0), -1) + + def decode(self, z): + return self.decoder(z.view(z.size(0), -1, 1, 1)) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + +train_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=True, download=True +) +train_input = train_set.data.view(-1, 1, 28, 28).float() + +test_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=False, download=True +) +test_input = test_set.data.view(-1, 1, 28, 28).float() + +###################################################################### + +model = AutoEncoder(args.nb_channels, args.embedding_dim) + +# model = AttentionAE( +# dim_model=16, +# dim_keys=16, +# dim_hidden=16, +# nb_heads=4, +# nb_blocks=4, +# dropout=0.0, +# len_max=1e5, +# ) + +# model = ResAutoEncoder(nb_channels=128, kernel_size=9) + +print(model) + +optimizer = optim.Adam(model.parameters(), lr=1e-3) + +model.to(device) + +train_input, test_input = train_input.to(device), test_input.to(device) + +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +nb_iterations = 10 + +###################################################################### + + +def dist(u, v): + return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + + +def pb(e, desc): + return tqdm( + e, + dynamic_ncols=True, + desc=desc, + total=train_input.size(0) // args.batch_size, + delay=10, + ) + + +for n_epoch in range(args.nb_epochs): + acc_loss = 0 + + for targets in pb(train_input.split(args.batch_size), "train"): + input = torch.randn(targets.size(), device=targets.device) + + loss = 0 + for n in range(nb_iterations): + output = model(input) + current_d = dist(targets, output) + nb_remain = nb_iterations - n + tolerated_d = dist(targets, input) * (nb_remain - 1) / nb_remain + a = (tolerated_d / (current_d + 1e-6)).clamp(max=1) + loss += (1 - a).mean() / nb_iterations + input = targets - a * (targets - output.detach()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_loss += loss.item() + + log_string(f"acc_loss {n_epoch} {acc_loss}") + + ###################################################################### + + input = test_input[:256] + model.eval() + + input = torch.randn(input.size(), device=input.device) + for _ in range(nb_iterations): + output = model(input) + input = output.detach() + + output = (output * std + mu) / 255 + + torchvision.utils.save_image( + 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8 + ) + + +###################################################################### -- 2.39.5 From 1399a4f5270eba99e8557125408826f26910b539 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 1 Oct 2024 21:38:51 +0200 Subject: [PATCH 03/16] Update. --- difdis.py | 575 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 575 insertions(+) create mode 100755 difdis.py diff --git a/difdis.py b/difdis.py new file mode 100755 index 0000000..5ec408f --- /dev/null +++ b/difdis.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_EXEC: python +# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate +# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data +# @XREMOTE_GET: *.png + +import sys, argparse, time, math + +import torch, torchvision + +from torch import optim, nn +from torch.nn import functional as F +from tqdm import tqdm + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### + +parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") + +parser.add_argument("--nb_epochs", type=int, default=250) + +parser.add_argument("--batch_size", type=int, default=100) + +parser.add_argument("--nb_train_samples", type=int, default=-1) + +parser.add_argument("--nb_test_samples", type=int, default=-1) + +parser.add_argument("--data_dir", type=str, default="./data/") + +parser.add_argument("--log_filename", type=str, default="train.log") + +parser.add_argument("--embedding_dim", type=int, default=64) + +parser.add_argument("--nb_channels", type=int, default=64) + +args = parser.parse_args() + +log_file = open(args.log_filename, "w") + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +###################################################################### + + +class VaswaniPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + def forward(self, x): + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 # works with float, weird + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) + y = x + pe + return y + + +###################################################################### + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return x + self.f(x) + + +###################################################################### + + +def vanilla_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + +###################################################################### + + +class MHAttention(nn.Module): + def __init__( + self, + dim_model, + dim_qk, + dim_v, + nb_heads=1, + attention=vanilla_attention, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.attention = attention + self.attention_dropout = attention_dropout + self.w_q = randw(nb_heads, dim_qk, dim_model) + self.w_k = randw(nb_heads, dim_qk, dim_model) + self.w_v = randw(nb_heads, dim_v, dim_model) + self.w_o = randw(nb_heads, dim_v, dim_model) + + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q + + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k) + v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v) + y = self.attention(q, k, v) + y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) + + return y + + +###################################################################### + + +class AttentionAE(nn.Module): + def __init__( + self, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + nn.Linear(2, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=dropout, + ), + ), + WithResidual( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=1) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1) + x = self.embedding(x) + x = self.positional_encoding(x) + x = self.trunk(x) + x = self.readout(x).reshape(-1, 1, 28, 28) + return x + + +###################################################################### + + +class WithMaskedResidual(nn.Module): + def __init__(self, masker, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.masker = masker + self.mask = None + + def forward(self, x): + if self.mask is None: + self.mask = self.masker(x) + return self.mask * x + self.f(x) + + +###################################################################### + + +class FunctionalAttentionAE(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + nb_work_tokens=100, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.nb_work_tokens = nb_work_tokens + + self.embedding = nn.Sequential( + nn.Embedding(2 * vocabulary_size, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + def no_peek_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + n = self.nb_work_tokens + s = (q.size(2) - n) // 2 + a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf") + a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf") + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + def masker(x): + m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens + return m[None, :, None] + + for b in range(nb_blocks): + trunk_blocks += [ + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=no_peek_attention, + attention_dropout=dropout, + ), + ), + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = self.embedding(x) + x = F.pad(x, (0, 0, self.nb_work_tokens, 0)) + x = self.positional_encoding(x) + x = self.trunk(x) + x = F.pad(x, (0, 0, -self.nb_work_tokens, 0)) + x = self.readout(x) + return x + + +###################################################################### + + +class FullAveragePooling(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1) + return x + + +class ResNetBlock(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.conv1 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + self.conv2 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + def forward(self, x): + y = F.relu(self.conv1(x)) + y = F.relu(x + self.conv2(y)) + return y + + +###################################################################### + + +class ResAutoEncoder(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.encoder = nn.Conv2d( + 2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + self.core = nn.Sequential( + *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)] + ) + self.decoder = nn.Conv2d( + nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + + +class AutoEncoder(nn.Module): + def __init__(self, nb_channels, embedding_dim): + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(256, nb_channels, kernel_size=5), # to 24x24 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 20x20 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2), # to 9x9 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2), # to 4x4 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, embedding_dim, kernel_size=4), + ) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=3, stride=2 + ), # from 4x4 + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=4, stride=2 + ), # from 9x9 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5), # from 20x20 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, 256, kernel_size=5), # from 24x24 + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + +train_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=True, download=True +) +train_input = train_set.data.view(-1, 28, 28).long() + +test_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=False, download=True +) +test_input = test_set.data.view(-1, 28, 28).long() + +if args.nb_train_samples > 0: + train_input = train_input[: args.nb_train_samples] + +if args.nb_test_samples > 0: + test_input = test_input[: args.nb_test_samples] + +###################################################################### + +model = AutoEncoder(args.nb_channels, args.embedding_dim) + +# model = AttentionAE( +# dim_model=16, +# dim_keys=16, +# dim_hidden=16, +# nb_heads=4, +# nb_blocks=4, +# dropout=0.0, +# len_max=1e5, +# ) + +# model = ResAutoEncoder(nb_channels=128, kernel_size=9) + +print(model) + +optimizer = optim.Adam(model.parameters(), lr=1e-3) + +model.to(device) + +train_input, test_input = train_input.to(device), test_input.to(device) + +nb_iterations = 10 + +###################################################################### + + +def dist(u, v): + return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + + +def kl(u, v): + s = (u.size(0),) + (1,) * (u.dim() - 1) + u = u.flatten(1).log_softmax(dim=1) + v = v.flatten(1).log_softmax(dim=1) + return ( + F.kl_div(u, v, log_target=True, reduction="none") + .sum(dim=1, keepdim=True) + .clamp(min=0) + .reshape(*s) + ) + + +def pb(e, desc): + return tqdm( + e, + dynamic_ncols=True, + desc=desc, + total=train_input.size(0) // args.batch_size, + delay=10, + ) + + +for n_epoch in range(args.nb_epochs): + acc_loss = 0 + + for targets in pb(train_input.split(args.batch_size), "train"): + targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0 + input = torch.randn(targets.size(), device=targets.device) * 1e-3 + + # print(f"-----------------") + + loss = 0 + for n in range(nb_iterations): + output = model(input).clamp(min=-10, max=10) + # assert not output.isnan().any() + current_kl = kl(output, targets) + + # assert not current_kl.isnan().any() + + nb_remain = nb_iterations - n + tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain + + # assert not tolerated_kl.isnan().any() + + with torch.no_grad(): + a = input.new_full((input.size(0), 1, 1, 1), 0.0) + b = input.new_full((input.size(0), 1, 1, 1), 1.0) + + # KL(a * output + (1-a) * targets) = 0 + # KL(b * output + (1-b) * targets) >> 0 + + for _ in range(10): + c = (a + b) / 2 + kl_c = kl(c * output + (1 - c) * targets, targets) + # print(f"{c.size()=} {output.size()=} {targets.size()=} {kl_c.size()=}") + # print(f"{kl_c}") + # assert kl_c.min() >= 0 + # assert not kl_c.isnan().any() + m = (kl_c >= tolerated_kl).long() + a = m * a + (1 - m) * c + b = m * c + (1 - m) * b + # print(f"{a.size()=} {b.size()=} {m.size()=}") + + c = (a + b) / 2 + + # print() + # print(tolerated_kl.flatten()) + # print(kl_c.flatten()) + # print(c.flatten()) + + input = c * output + (1 - c) * targets + loss += kl(output, input).mean() + # assert not loss.isnan() + input = input.detach() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_loss += loss.item() + + log_string(f"acc_loss {n_epoch} {acc_loss}") + + ###################################################################### + + targets = test_input[:256] + targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0 + input = torch.randn(targets.size(), device=targets.device) * 1e-3 + model.eval() + + input = torch.randn(input.size(), device=input.device) + for _ in range(nb_iterations): + output = model(input) + input = output.detach() + + output = ( + output.softmax(dim=1) + * torch.arange(output.size(1), device=output.device)[None, :, None, None] + ).sum(dim=1) + output = output[:, None, :, :] / 255 + + torchvision.utils.save_image( + 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8 + ) + + +###################################################################### -- 2.39.5 From ca5136e203f3dc9537aadb4071b786e34f1d7f39 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 3 Oct 2024 19:35:01 +0200 Subject: [PATCH 04/16] Update. --- difdis.py | 71 +++++++++---------- grid.py | 206 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+), 39 deletions(-) create mode 100755 grid.py diff --git a/difdis.py b/difdis.py index 5ec408f..34fef17 100755 --- a/difdis.py +++ b/difdis.py @@ -412,7 +412,7 @@ class AutoEncoder(nn.Module): def forward(self, x): x = self.encoder(x) x = self.decoder(x) - return x + return x * 1e-3 ###################################################################### @@ -488,28 +488,35 @@ def pb(e, desc): ) +def save_logit_image(logit_image, filename): + image = ( + logit_image.softmax(dim=1) + * torch.arange(logit_image.size(1), device=logit_image.device)[ + None, :, None, None + ] + ).sum(dim=1, keepdim=True) + image = image / 255 + + torchvision.utils.save_image(1 - image, filename, nrow=16, pad_value=0.8) + + for n_epoch in range(args.nb_epochs): acc_loss = 0 for targets in pb(train_input.split(args.batch_size), "train"): - targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0 + targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) + targets = (targets * 0.9 + 0.1 / targets.size(1)).log() input = torch.randn(targets.size(), device=targets.device) * 1e-3 - # print(f"-----------------") - loss = 0 for n in range(nb_iterations): + input = input.log_softmax(dim=1) output = model(input).clamp(min=-10, max=10) - # assert not output.isnan().any() current_kl = kl(output, targets) - # assert not current_kl.isnan().any() - nb_remain = nb_iterations - n tolerated_kl = kl(input, targets) * (nb_remain - 1) / nb_remain - # assert not tolerated_kl.isnan().any() - with torch.no_grad(): a = input.new_full((input.size(0), 1, 1, 1), 0.0) b = input.new_full((input.size(0), 1, 1, 1), 1.0) @@ -517,28 +524,21 @@ for n_epoch in range(args.nb_epochs): # KL(a * output + (1-a) * targets) = 0 # KL(b * output + (1-b) * targets) >> 0 - for _ in range(10): + for _ in range(20): c = (a + b) / 2 - kl_c = kl(c * output + (1 - c) * targets, targets) - # print(f"{c.size()=} {output.size()=} {targets.size()=} {kl_c.size()=}") - # print(f"{kl_c}") - # assert kl_c.min() >= 0 - # assert not kl_c.isnan().any() + kl_c = kl((1 - c) * output + c * targets, targets) m = (kl_c >= tolerated_kl).long() - a = m * a + (1 - m) * c - b = m * c + (1 - m) * b - # print(f"{a.size()=} {b.size()=} {m.size()=}") + # m=1 -> kl > tolerated => a = c + # m=0 -> kl < tolerated => b = c + a = m * c + (1 - m) * a + b = m * b + (1 - m) * c c = (a + b) / 2 - # print() - # print(tolerated_kl.flatten()) - # print(kl_c.flatten()) - # print(c.flatten()) + # print((kl_c / (tolerated_kl+1e-6)).flatten()) - input = c * output + (1 - c) * targets + input = (1 - c) * output + c * targets loss += kl(output, input).mean() - # assert not loss.isnan() input = input.detach() optimizer.zero_grad() @@ -549,27 +549,20 @@ for n_epoch in range(args.nb_epochs): log_string(f"acc_loss {n_epoch} {acc_loss}") + save_logit_image(output, f"train_output_{n_epoch:04d}.png") + ###################################################################### - targets = test_input[:256] - targets = F.one_hot(targets, num_classes=256).permute(0, 3, 1, 2) * 9.0 + 1.0 - input = torch.randn(targets.size(), device=targets.device) * 1e-3 model.eval() - input = torch.randn(input.size(), device=input.device) + input = torch.randn((256, 256, 28, 28), device=targets.device) * 1e-3 + input = input.log_softmax(dim=1) + for _ in range(nb_iterations): - output = model(input) + input = input.log_softmax(dim=1) + output = model(input).clamp(min=-10, max=10) input = output.detach() - output = ( - output.softmax(dim=1) - * torch.arange(output.size(1), device=output.device)[None, :, None, None] - ).sum(dim=1) - output = output[:, None, :, :] / 255 - - torchvision.utils.save_image( - 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8 - ) - + save_logit_image(output, f"test_output_{n_epoch:04d}.png") ###################################################################### diff --git a/grid.py b/grid.py new file mode 100755 index 0000000..fcf741b --- /dev/null +++ b/grid.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +# This code implement a simple system to manipulate formal +# specifications of tokens on a grid. + +import math, re + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + + +class FormalGrid: + def __init__( + self, grid_height=8, grid_width=8, nb_symbols=4, device=torch.device("cpu") + ): + self.grid_height = grid_height + self.grid_width = grid_width + self.nb_symbols = nb_symbols + self.nb_configs = (self.grid_height * self.grid_width) ** self.nb_symbols + + self.row = torch.empty( + self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device + ) + self.col = torch.empty( + self.nb_configs, self.nb_symbols, dtype=torch.int8, device=device + ) + + i = torch.arange(self.nb_configs, device=device) + + k = 1 + for s in range(self.nb_symbols): + self.row[:, s] = (i // k) % self.grid_height + k *= self.grid_height + self.col[:, s] = (i // k) % self.grid_width + k *= self.grid_width + + nb = 0 + for s in range(self.nb_symbols): + nb += F.one_hot( + (self.row[:, s] * self.grid_width + self.col[:, s]).long(), + num_classes=self.grid_height * self.grid_width, + ).to(torch.uint8) + self.master_grid_set = nb.max(dim=1).values <= 1 + + def new_grid_set(self, constraints=None): + g = self.master_grid_set.clone() + if constraints: + self.add_constraints(g, constraints) + return g + + ###################################################################### + + def constraint_to_fun(self, constraint): + g = [None] + + def match(pattern): + r = re.search("^" + pattern + "$", constraint) + if r: + g[0] = (int(x) - 1 for x in r.groups()) + return True + else: + return False + + if match("([1-9]) top"): + (a,) = g[0] + return self.row[:, a] < self.grid_height // 4 + elif match("([1-9]) bottom"): + (a,) = g[0] + return self.row[:, a] >= (self.grid_height * 3) // 4 + elif match("([1-9]) left"): + (a,) = g[0] + return self.col[:, a] < self.grid_width // 4 + elif match("([1-9]) right"): + (a,) = g[0] + return self.col[:, a] >= (self.grid_width * 3) // 4 + elif match("([1-9]) next_to ([1-9])"): + a, b = g[0] + return (self.row[:, a] - self.row[:, b]).abs() + ( + self.col[:, a] - self.col[:, b] + ).abs() <= 1 + elif match("([1-9]) below_of ([1-9])"): + a, b = g[0] + return self.row[:, a] > self.row[:, b] + elif match("([1-9]) above ([1-9])"): + a, b = g[0] + return self.row[:, a] < self.row[:, b] + elif match("([1-9]) left_of ([1-9])"): + a, b = g[0] + return self.col[:, a] < self.col[:, b] + elif match("([1-9]) right_of ([1-9])"): + a, b = g[0] + return self.col[:, a] > self.col[:, b] + elif match("([1-9]) ([1-9]) diagonal"): + a, b = g[0] + return (self.col[:, a] - self.col[:, b]).abs() == ( + self.row[:, a] - self.row[:, b] + ).abs() + elif match("([1-9]) ([1-9]) vertical"): + a, b = g[0] + return self.col[:, a] == self.col[:, b] + elif match("([1-9]) ([1-9]) horizontal"): + a, b = g[0] + return self.row[:, a] == self.row[:, b] + + elif match("([1-9]) ([1-9]) ([1-9]) aligned"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) * ( + self.row[:, a] - self.row[:, c] + ) - (self.row[:, a] - self.row[:, b]) * ( + self.col[:, a] - self.col[:, c] + ) == 0 + + elif match("([1-9]) middle_of ([1-9]) ([1-9])"): + a, b, c = g[0] + return ( + grid_set + & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b]) + & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b]) + ) + + elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) ** 2 + ( + self.row[:, a] - self.row[:, b] + ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + ( + self.row[:, a] - self.row[:, c] + ) ** 2 + + elif match("([1-9]) ([1-9]) ([1-9]) right_angle"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) * ( + self.col[:, c] - self.col[:, b] + ) + (self.row[:, a] - self.row[:, b]) * ( + self.row[:, c] - self.row[:, b] + ) == 0 + + else: + raise ValueError(f"Unknown type of constraint {constraint}") + + ###################################################################### + + def check_reasonning(self, steps): + grid_set = grid.new_grid_set() + for step in steps: + if step[0] == "=>": + f = self.constraint_to_fun(step[1:]) + if (grid_set & torch.logical_not(f)).any(): + return False, step[1:] + else: + grid_set[...] = grid_set & self.constraint_to_fun(step) + + return True, None + + def add_constraints(self, grid_set, constraints): + for constraint in constraints: + grid_set[...] = grid_set & self.constraint_to_fun(constraint) + + ###################################################################### + + def views(self, grid_set): + g = torch.empty(self.grid_height, self.grid_width, dtype=torch.int64) + row, col = self.row[grid_set], self.col[grid_set] + i = torch.randperm(row.size(0)) + row, col = row[i], col[i] + for r, c in zip(row, col): + g.zero_() + for k in range(self.nb_symbols): + g[r[k], c[k]] = k + 1 + v = "" + for r in g: + v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n" + yield v + + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +grid = FormalGrid(device=device) + +grid_set = grid.new_grid_set( + [ + "2 top", + "4 right", + "1 left_of 2", + "2 left_of 3", + "1 2 4 right_angle", + "1 2 3 aligned", + # "3 2 diagonal", + "2 further_away_from 3 than 4", + ], +) + +print(f"There are {grid_set.long().sum().item()} configurations") + +for v in grid.views(grid_set): + print(v) -- 2.39.5 From b43f300c8cdd1896555db7b246e19bdc548ab691 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 3 Oct 2024 19:53:49 +0200 Subject: [PATCH 05/16] Update. --- grid.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/grid.py b/grid.py index fcf741b..bddda6d 100755 --- a/grid.py +++ b/grid.py @@ -189,13 +189,14 @@ grid = FormalGrid(device=device) grid_set = grid.new_grid_set( [ - "2 top", + "4 top", "4 right", + "1 top", + "1 left", "1 left_of 2", "2 left_of 3", "1 2 4 right_angle", "1 2 3 aligned", - # "3 2 diagonal", "2 further_away_from 3 than 4", ], ) -- 2.39.5 From f0a5ee3f57758067e16582d358cc41ed12582c89 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 5 Oct 2024 04:47:36 +0200 Subject: [PATCH 06/16] Update. --- grid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grid.py b/grid.py index bddda6d..08ca225 100755 --- a/grid.py +++ b/grid.py @@ -3,6 +3,7 @@ # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ + # Written by Francois Fleuret # This code implement a simple system to manipulate formal -- 2.39.5 From bc9b6d2abf25569ecdf3b40cafe0090133376a93 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 5 Oct 2024 04:54:41 +0200 Subject: [PATCH 07/16] Test --- grid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grid.py b/grid.py index 08ca225..991c44f 100755 --- a/grid.py +++ b/grid.py @@ -4,6 +4,7 @@ # https://creativecommons.org/publicdomain/zero/1.0/ + # Written by Francois Fleuret # This code implement a simple system to manipulate formal -- 2.39.5 From 03f7803f6770063c07b8000e06cbe6c694cdaf41 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 26 Oct 2024 23:53:32 +0200 Subject: [PATCH 08/16] Update. --- grid.py | 74 +++++++++++++++++++++++++++------------------------- tinymnist.py | 8 +++--- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/grid.py b/grid.py index 991c44f..7168491 100755 --- a/grid.py +++ b/grid.py @@ -4,7 +4,6 @@ # https://creativecommons.org/publicdomain/zero/1.0/ - # Written by Francois Fleuret # This code implement a simple system to manipulate formal @@ -72,36 +71,36 @@ class FormalGrid: else: return False - if match("([1-9]) top"): + if match("([1-9]) is_in_top_half"): (a,) = g[0] - return self.row[:, a] < self.grid_height // 4 - elif match("([1-9]) bottom"): + return self.row[:, a] < self.grid_height // 2 + elif match("([1-9]) is_in_bottom_half"): (a,) = g[0] - return self.row[:, a] >= (self.grid_height * 3) // 4 - elif match("([1-9]) left"): + return self.row[:, a] >= self.grid_height // 2 + elif match("([1-9]) is_on_left_side"): (a,) = g[0] - return self.col[:, a] < self.grid_width // 4 - elif match("([1-9]) right"): + return self.col[:, a] < self.grid_width // 2 + elif match("([1-9]) is_on_right_side"): (a,) = g[0] - return self.col[:, a] >= (self.grid_width * 3) // 4 + return self.col[:, a] >= self.grid_width // 2 elif match("([1-9]) next_to ([1-9])"): a, b = g[0] return (self.row[:, a] - self.row[:, b]).abs() + ( self.col[:, a] - self.col[:, b] ).abs() <= 1 - elif match("([1-9]) below_of ([1-9])"): + elif match("([1-9]) is_below ([1-9])"): a, b = g[0] return self.row[:, a] > self.row[:, b] - elif match("([1-9]) above ([1-9])"): + elif match("([1-9]) is_above ([1-9])"): a, b = g[0] return self.row[:, a] < self.row[:, b] - elif match("([1-9]) left_of ([1-9])"): + elif match("([1-9]) is_left_of ([1-9])"): a, b = g[0] return self.col[:, a] < self.col[:, b] - elif match("([1-9]) right_of ([1-9])"): + elif match("([1-9]) is_right_of ([1-9])"): a, b = g[0] return self.col[:, a] > self.col[:, b] - elif match("([1-9]) ([1-9]) diagonal"): + elif match("([1-9]) ([1-9]) parallel_to_diagonal"): a, b = g[0] return (self.col[:, a] - self.col[:, b]).abs() == ( self.row[:, a] - self.row[:, b] @@ -113,7 +112,7 @@ class FormalGrid: a, b = g[0] return self.row[:, a] == self.row[:, b] - elif match("([1-9]) ([1-9]) ([1-9]) aligned"): + elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"): a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) * ( self.row[:, a] - self.row[:, c] @@ -129,7 +128,15 @@ class FormalGrid: & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b]) ) - elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"): + elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) ** 2 + ( + self.row[:, a] - self.row[:, b] + ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + ( + self.row[:, a] - self.row[:, c] + ) ** 2 + + elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"): a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) ** 2 + ( self.row[:, a] - self.row[:, b] @@ -137,7 +144,7 @@ class FormalGrid: self.row[:, a] - self.row[:, c] ) ** 2 - elif match("([1-9]) ([1-9]) ([1-9]) right_angle"): + elif match("([1-9]) ([1-9]) ([1-9]) make_right_angle"): a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) * ( self.col[:, c] - self.col[:, b] @@ -185,25 +192,22 @@ class FormalGrid: ###################################################################### -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -grid = FormalGrid(device=device) + grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device) -grid_set = grid.new_grid_set( - [ - "4 top", - "4 right", - "1 top", - "1 left", - "1 left_of 2", - "2 left_of 3", - "1 2 4 right_angle", - "1 2 3 aligned", - "2 further_away_from 3 than 4", - ], -) + grid_set = grid.new_grid_set( + [ + "1 2 3 make_right_angle", + "2 3 4 make_right_angle", + "3 4 1 make_right_angle", + "2 is_equidistant_from 1 and 3", + "1 is_above 4", + ], + ) -print(f"There are {grid_set.long().sum().item()} configurations") + print(f"There are {grid_set.long().sum().item()} configurations") -for v in grid.views(grid_set): - print(v) + for v in grid.views(grid_set): + print(v) diff --git a/tinymnist.py b/tinymnist.py index 896477e..f662be6 100755 --- a/tinymnist.py +++ b/tinymnist.py @@ -70,14 +70,14 @@ test_input.sub_(mu).div_(std) start_time = time.perf_counter() for k in range(nb_epochs): - acc_loss = 0.0 + acc_train_loss = 0.0 for input, targets in zip( train_input.split(batch_size), train_targets.split(batch_size) ): output = model(input) loss = criterion(output, targets) - acc_loss += loss.item() + acc_train_loss += loss.item() * input.size(0) optimizer.zero_grad() loss.backward() @@ -92,6 +92,8 @@ for k in range(nb_epochs): test_error = nb_test_errors / test_input.size(0) duration = time.perf_counter() - start_time - print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%") + print( + f"loss {k} {duration:.02f}s {acc_train_loss/train_input.size(0):.02f} {test_error*100:.02f}%" + ) ###################################################################### -- 2.39.5 From 5ca289a1d537177b1d1b85159891c81a0de91cbc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 28 Oct 2024 13:16:09 +0100 Subject: [PATCH 09/16] Update. --- distributed.py | 138 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100755 distributed.py diff --git a/distributed.py b/distributed.py new file mode 100755 index 0000000..adaa36f --- /dev/null +++ b/distributed.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python + +import time, socket, threading, struct, pickle + +import math, sys, argparse + +###################################################################### + +parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) + +parser.add_argument("--server_host", type=str, default=None) + +parser.add_argument("--port", type=int, default=12021) + +args = parser.parse_args() + +###################################################################### + + +class SocketConnection: + def __init__(self, established_socket, read_len=16384): + self.read_len = read_len + self.socket = established_socket + self.socket.setblocking(1) + self.buffer = b"" + self.SEND_LOCK = threading.Lock() + self.RECEIVE_LOCK = threading.Lock() + self.failed = False + + def send(self, x): + with self.SEND_LOCK: + data = pickle.dumps(x) + self.socket.send(struct.pack("!i", len(data))) + self.socket.sendall(data) + + def raw_read(self, l): + while len(self.buffer) < l: + d = self.socket.recv(self.read_len) + if d: + self.buffer += d + else: + raise EOFError() + + x = self.buffer[:l] + self.buffer = self.buffer[l:] + return x + + def receive(self): + with self.RECEIVE_LOCK: + l = struct.unpack("!i", self.raw_read(4))[0] + return pickle.loads(self.raw_read(l)) + + +###################################################################### + + +class CultureServer: + def __init__(self, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("0.0.0.0", port)) + s.listen(5) + self.nb_accepts = 0 + + while True: + client_socket, ip_and_port = s.accept() + link = SocketConnection(client_socket) + + threading.Thread( + target=self.client_loop, + kwargs={ + "link": link, + "nb": self.nb_accepts, + }, + daemon=True, + ).start() + + self.nb_accepts += 1 + + def client_loop(self, link, nb): + link.send(f"HELLO #{nb}") + try: + while True: + r = link.receive() + print(f'from #{nb} receive "{r}"') + link.send(f"ACK {r}") + if r == "BYE": + break + except EOFError: + print(f"closing #{nb} on EOF") + + +###################################################################### + + +class CultureClient: + def __init__(self, server_hostname, port): + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.connect((server_hostname, port)) + self.link = SocketConnection(server_socket) + + threading.Thread(target=self.receive, daemon=True).start() + + self.send() + # threading.Thread(target=self.send, daemon=True).start() + + def receive(self): + try: + while True: + x = self.link.receive() + print(f'CultureClient receive "{x}"') + except EOFError: + print(f"closing connection on EOF") + + def send(self): + try: + self.link.send(f"HELLO") + x = 0 + while True: + time.sleep(5) + print(f'CultureClient send "{x}"') + self.link.send(x) + x += 1 + except BrokenPipeError: + print(f"closing connection on broken pipe") + + +###################################################################### + +if args.server_host is None: + print(f"Starting server port {args.port}") + CultureServer(args.port) +else: + print(f"Starting client connecting to {args.server_host}:{args.port}") + CultureClient(args.server_host, args.port) + +###################################################################### -- 2.39.5 From 606a49cbf6f6f20963794e3947a20d2b0807b3e9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 28 Oct 2024 20:28:12 +0100 Subject: [PATCH 10/16] Update. --- distributed.py | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/distributed.py b/distributed.py index adaa36f..b5e98e9 100755 --- a/distributed.py +++ b/distributed.py @@ -69,26 +69,23 @@ class CultureServer: threading.Thread( target=self.client_loop, - kwargs={ - "link": link, - "nb": self.nb_accepts, - }, + kwargs={"link": link, "client_nb": self.nb_accepts}, daemon=True, ).start() self.nb_accepts += 1 - def client_loop(self, link, nb): - link.send(f"HELLO #{nb}") + def client_loop(self, link, client_nb): + link.send(f"HELLO #{client_nb}") try: while True: r = link.receive() - print(f'from #{nb} receive "{r}"') + print(f'from #{client_nb} receive "{r}"') link.send(f"ACK {r}") if r == "BYE": break except EOFError: - print(f"closing #{nb} on EOF") + print(f"closing #{client_nb} on EOF") ###################################################################### @@ -99,31 +96,33 @@ class CultureClient: server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.connect((server_hostname, port)) self.link = SocketConnection(server_socket) + self.buffer = [] threading.Thread(target=self.receive, daemon=True).start() - - self.send() - # threading.Thread(target=self.send, daemon=True).start() + self.loop() def receive(self): try: while True: - x = self.link.receive() - print(f'CultureClient receive "{x}"') + self.buffer.append(self.link.receive()) except EOFError: - print(f"closing connection on EOF") + print(f"** closing connection on EOF **") - def send(self): + def loop(self): try: - self.link.send(f"HELLO") - x = 0 while True: - time.sleep(5) - print(f'CultureClient send "{x}"') - self.link.send(x) - x += 1 + self.link.send(f"PING {time.localtime().tm_sec}") + + try: + while True: + print(self.buffer.pop(0)) + except IndexError: + pass + + time.sleep(1) + except BrokenPipeError: - print(f"closing connection on broken pipe") + print(f"** closing connection on broken pipe **") ###################################################################### -- 2.39.5 From 9c95bbe314e05c856d7a551b988d7a3fbad033b9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 29 Oct 2024 08:06:43 +0100 Subject: [PATCH 11/16] Update. --- distributed.py | 89 +++++++++++++++++++++----------------------------- grid.py | 42 +++++++++++------------- 2 files changed, 56 insertions(+), 75 deletions(-) diff --git a/distributed.py b/distributed.py index b5e98e9..6fe3f6b 100755 --- a/distributed.py +++ b/distributed.py @@ -56,73 +56,58 @@ class SocketConnection: ###################################################################### -class CultureServer: - def __init__(self, port): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("0.0.0.0", port)) - s.listen(5) - self.nb_accepts = 0 +def create_server(port, reader, writer): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("0.0.0.0", port)) + s.listen(5) + self.nb_accepts = 0 - while True: - client_socket, ip_and_port = s.accept() - link = SocketConnection(client_socket) - - threading.Thread( - target=self.client_loop, - kwargs={"link": link, "client_nb": self.nb_accepts}, - daemon=True, - ).start() - - self.nb_accepts += 1 - - def client_loop(self, link, client_nb): - link.send(f"HELLO #{client_nb}") + def reading_loop(self, link, client_nb): try: while True: - r = link.receive() - print(f'from #{client_nb} receive "{r}"') - link.send(f"ACK {r}") - if r == "BYE": - break + reader(link.receive()) except EOFError: print(f"closing #{client_nb} on EOF") + while True: + client_socket, ip_and_port = s.accept() + link = SocketConnection(client_socket) -###################################################################### + threading.Thread( + target=writer, + kwargs={"link": link.send, "client_nb": self.nb_accepts}, + daemon=True, + ).start() + threading.Thread( + target=reading_loop, + kwargs={"link": link, "client_nb": self.nb_accepts}, + daemon=True, + ).start() -class CultureClient: - def __init__(self, server_hostname, port): - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.connect((server_hostname, port)) - self.link = SocketConnection(server_socket) - self.buffer = [] + self.nb_accepts += 1 - threading.Thread(target=self.receive, daemon=True).start() - self.loop() - def receive(self): - try: - while True: - self.buffer.append(self.link.receive()) - except EOFError: - print(f"** closing connection on EOF **") +###################################################################### - def loop(self): - try: - while True: - self.link.send(f"PING {time.localtime().tm_sec}") - try: - while True: - print(self.buffer.pop(0)) - except IndexError: - pass +def create_client(server_hostname, port, reader): + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.connect((server_hostname, port)) + link = SocketConnection(server_socket) + + def reader_thread(reader): + while True: + reader(link.receive()) + + def writer(x): + self.link.send(x) - time.sleep(1) + threading.Thread( + target=reader_thread, kwargs={"reader": reader}, daemon=True + ).start() - except BrokenPipeError: - print(f"** closing connection on broken pipe **") + return writer ###################################################################### diff --git a/grid.py b/grid.py index 7168491..12ecba1 100755 --- a/grid.py +++ b/grid.py @@ -61,59 +61,59 @@ class FormalGrid: ###################################################################### def constraint_to_fun(self, constraint): - g = [None] + a, b, c = None, None, None def match(pattern): + nonlocal a, b, c r = re.search("^" + pattern + "$", constraint) if r: - g[0] = (int(x) - 1 for x in r.groups()) + g = tuple(int(x) - 1 for x in r.groups()) + a, b, c = g + (None,) * (3 - len(g)) return True else: return False if match("([1-9]) is_in_top_half"): - (a,) = g[0] return self.row[:, a] < self.grid_height // 2 + elif match("([1-9]) is_in_bottom_half"): - (a,) = g[0] return self.row[:, a] >= self.grid_height // 2 + elif match("([1-9]) is_on_left_side"): - (a,) = g[0] return self.col[:, a] < self.grid_width // 2 + elif match("([1-9]) is_on_right_side"): - (a,) = g[0] return self.col[:, a] >= self.grid_width // 2 + elif match("([1-9]) next_to ([1-9])"): - a, b = g[0] return (self.row[:, a] - self.row[:, b]).abs() + ( self.col[:, a] - self.col[:, b] ).abs() <= 1 + elif match("([1-9]) is_below ([1-9])"): - a, b = g[0] return self.row[:, a] > self.row[:, b] + elif match("([1-9]) is_above ([1-9])"): - a, b = g[0] return self.row[:, a] < self.row[:, b] + elif match("([1-9]) is_left_of ([1-9])"): - a, b = g[0] return self.col[:, a] < self.col[:, b] + elif match("([1-9]) is_right_of ([1-9])"): - a, b = g[0] return self.col[:, a] > self.col[:, b] + elif match("([1-9]) ([1-9]) parallel_to_diagonal"): - a, b = g[0] return (self.col[:, a] - self.col[:, b]).abs() == ( self.row[:, a] - self.row[:, b] ).abs() + elif match("([1-9]) ([1-9]) vertical"): - a, b = g[0] return self.col[:, a] == self.col[:, b] + elif match("([1-9]) ([1-9]) horizontal"): - a, b = g[0] return self.row[:, a] == self.row[:, b] elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"): - a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) * ( self.row[:, a] - self.row[:, c] ) - (self.row[:, a] - self.row[:, b]) * ( @@ -121,7 +121,6 @@ class FormalGrid: ) == 0 elif match("([1-9]) middle_of ([1-9]) ([1-9])"): - a, b, c = g[0] return ( grid_set & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b]) @@ -129,7 +128,6 @@ class FormalGrid: ) elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"): - a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) ** 2 + ( self.row[:, a] - self.row[:, b] ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + ( @@ -137,15 +135,13 @@ class FormalGrid: ) ** 2 elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"): - a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) ** 2 + ( self.row[:, a] - self.row[:, b] ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + ( self.row[:, a] - self.row[:, c] ) ** 2 - elif match("([1-9]) ([1-9]) ([1-9]) make_right_angle"): - a, b, c = g[0] + elif match("([1-9]) ([1-9]) ([1-9]) form_a_right_angle"): return (self.col[:, a] - self.col[:, b]) * ( self.col[:, c] - self.col[:, b] ) + (self.row[:, a] - self.row[:, b]) * ( @@ -199,9 +195,9 @@ if __name__ == "__main__": grid_set = grid.new_grid_set( [ - "1 2 3 make_right_angle", - "2 3 4 make_right_angle", - "3 4 1 make_right_angle", + "1 2 3 form_a_right_angle", + "2 3 4 form_a_right_angle", + "3 4 1 form_a_right_angle", "2 is_equidistant_from 1 and 3", "1 is_above 4", ], -- 2.39.5 From 96429a6891d09c994a13a0b6969c7f82c45945a7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 29 Oct 2024 08:21:47 +0100 Subject: [PATCH 12/16] Update. --- grid.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/grid.py b/grid.py index 12ecba1..d06802d 100755 --- a/grid.py +++ b/grid.py @@ -11,9 +11,8 @@ import math, re -import torch, torchvision +import torch -from torch import nn from torch.nn import functional as F ###################################################################### @@ -49,7 +48,7 @@ class FormalGrid: nb += F.one_hot( (self.row[:, s] * self.grid_width + self.col[:, s]).long(), num_classes=self.grid_height * self.grid_width, - ).to(torch.uint8) + ).to(torch.int8) self.master_grid_set = nb.max(dim=1).values <= 1 def new_grid_set(self, constraints=None): @@ -102,15 +101,15 @@ class FormalGrid: elif match("([1-9]) is_right_of ([1-9])"): return self.col[:, a] > self.col[:, b] - elif match("([1-9]) ([1-9]) parallel_to_diagonal"): + elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"): return (self.col[:, a] - self.col[:, b]).abs() == ( self.row[:, a] - self.row[:, b] ).abs() - elif match("([1-9]) ([1-9]) vertical"): + elif match("([1-9]) ([1-9]) is_vertical"): return self.col[:, a] == self.col[:, b] - elif match("([1-9]) ([1-9]) horizontal"): + elif match("([1-9]) ([1-9]) is_horizontal"): return self.row[:, a] == self.row[:, b] elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"): @@ -193,13 +192,22 @@ if __name__ == "__main__": grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device) + # grid_set = grid.new_grid_set( + # [ + # "1 2 3 form_a_right_angle", + # "2 3 4 form_a_right_angle", + # "3 4 1 form_a_right_angle", + # "2 is_equidistant_from 1 and 3", + # "1 is_above 4", + # ], + # ) + grid_set = grid.new_grid_set( [ - "1 2 3 form_a_right_angle", - "2 3 4 form_a_right_angle", - "3 4 1 form_a_right_angle", - "2 is_equidistant_from 1 and 3", - "1 is_above 4", + "1 2 3 are_aligned", + "2 3 is_parallel_to_diagonal", + "4 1 is_vertical", + "3 4 is_horizontal", ], ) -- 2.39.5 From 97f71dd09a28d1fd3a7d62a323169e48ea6a38f7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 29 Oct 2024 09:21:05 +0100 Subject: [PATCH 13/16] Update. --- distributed.py | 61 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/distributed.py b/distributed.py index 6fe3f6b..91d33ee 100755 --- a/distributed.py +++ b/distributed.py @@ -10,7 +10,7 @@ parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) -parser.add_argument("--server_host", type=str, default=None) +parser.add_argument("--server", type=str, default=None) parser.add_argument("--port", type=int, default=12021) @@ -56,44 +56,50 @@ class SocketConnection: ###################################################################### -def create_server(port, reader, writer): +def start_server(port, core, reader): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("0.0.0.0", port)) s.listen(5) - self.nb_accepts = 0 + nb_accepts = 0 - def reading_loop(self, link, client_nb): + def reader_thread(reader, link, client_nb): try: while True: - reader(link.receive()) + reader(link.receive(), client_nb) except EOFError: - print(f"closing #{client_nb} on EOF") + print(f"closing reader #{client_nb} on EOFError") + + def core_thread(writer, client_nb): + try: + core(writer, client_nb) + except BrokenPipeError: + print(f"closing core #{client_nb} on BrokenPipeError") while True: client_socket, ip_and_port = s.accept() link = SocketConnection(client_socket) threading.Thread( - target=writer, - kwargs={"link": link.send, "client_nb": self.nb_accepts}, + target=core_thread, + kwargs={"writer": link.send, "client_nb": nb_accepts}, daemon=True, ).start() threading.Thread( - target=reading_loop, - kwargs={"link": link, "client_nb": self.nb_accepts}, + target=reader_thread, + kwargs={"reader": reader, "link": link, "client_nb": nb_accepts}, daemon=True, ).start() - self.nb_accepts += 1 + nb_accepts += 1 ###################################################################### -def create_client(server_hostname, port, reader): +def create_client(servername, port, reader): server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.connect((server_hostname, port)) + server_socket.connect((servername, port)) link = SocketConnection(server_socket) def reader_thread(reader): @@ -101,7 +107,7 @@ def create_client(server_hostname, port, reader): reader(link.receive()) def writer(x): - self.link.send(x) + link.send(x) threading.Thread( target=reader_thread, kwargs={"reader": reader}, daemon=True @@ -112,11 +118,30 @@ def create_client(server_hostname, port, reader): ###################################################################### -if args.server_host is None: +if args.server is None: print(f"Starting server port {args.port}") - CultureServer(args.port) + + def reader(x, nb): + print(f'Server received from client #{nb} "{x}"') + + def core(writer, client_nb): + writer(f"HELLO {client_nb}") + while True: + writer(f"PONG {time.localtime().tm_sec}") + time.sleep(3) + + start_server(port=args.port, core=core, reader=reader) + else: - print(f"Starting client connecting to {args.server_host}:{args.port}") - CultureClient(args.server_host, args.port) + print(f"Starting client connecting to {args.server}:{args.port}") + + def reader(x): + print(f'Client received "{x}"') + + writer = create_client(args.server, args.port, reader) + + while True: + writer(f"PING {time.localtime().tm_sec}") + time.sleep(3) ###################################################################### -- 2.39.5 From 95e7d5dcc8296e7ddb6326cb92c29d08bfb548b3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 29 Oct 2024 09:34:41 +0100 Subject: [PATCH 14/16] Update. --- distributed.py | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/distributed.py b/distributed.py index 91d33ee..648756e 100755 --- a/distributed.py +++ b/distributed.py @@ -2,7 +2,7 @@ import time, socket, threading, struct, pickle -import math, sys, argparse +import argparse ###################################################################### @@ -62,32 +62,36 @@ def start_server(port, core, reader): s.listen(5) nb_accepts = 0 - def reader_thread(reader, link, client_nb): + def threadable_reader(reader, receiver, client_id): try: while True: - reader(link.receive(), client_nb) + reader(link.receive(), client_id) except EOFError: - print(f"closing reader #{client_nb} on EOFError") + print(f"** closing reader #{client_id} on EOFError **") - def core_thread(writer, client_nb): + def threadable_core(sender, client_id): try: - core(writer, client_nb) + core(sender, client_id) except BrokenPipeError: - print(f"closing core #{client_nb} on BrokenPipeError") + print(f"** closing core #{client_id} on BrokenPipeError **") while True: client_socket, ip_and_port = s.accept() link = SocketConnection(client_socket) threading.Thread( - target=core_thread, - kwargs={"writer": link.send, "client_nb": nb_accepts}, + target=threadable_core, + kwargs={"sender": link.send, "client_id": nb_accepts}, daemon=True, ).start() threading.Thread( - target=reader_thread, - kwargs={"reader": reader, "link": link, "client_nb": nb_accepts}, + target=threadable_reader, + kwargs={ + "reader": reader, + "receiver": link.receive, + "client_id": nb_accepts, + }, daemon=True, ).start() @@ -102,7 +106,7 @@ def create_client(servername, port, reader): server_socket.connect((servername, port)) link = SocketConnection(server_socket) - def reader_thread(reader): + def threadable_reader(reader): while True: reader(link.receive()) @@ -110,7 +114,7 @@ def create_client(servername, port, reader): link.send(x) threading.Thread( - target=reader_thread, kwargs={"reader": reader}, daemon=True + target=threadable_reader, kwargs={"reader": reader}, daemon=True ).start() return writer @@ -121,11 +125,11 @@ def create_client(servername, port, reader): if args.server is None: print(f"Starting server port {args.port}") - def reader(x, nb): - print(f'Server received from client #{nb} "{x}"') + def reader(obj, client_id): + print(f'Server received from client #{client_id} "{obj}"') - def core(writer, client_nb): - writer(f"HELLO {client_nb}") + def core(writer, client_id): + writer(f"HELLO {client_id}") while True: writer(f"PONG {time.localtime().tm_sec}") time.sleep(3) @@ -135,8 +139,8 @@ if args.server is None: else: print(f"Starting client connecting to {args.server}:{args.port}") - def reader(x): - print(f'Client received "{x}"') + def reader(obj): + print(f'Client received from server "{obj}"') writer = create_client(args.server, args.port, reader) -- 2.39.5 From 32945dec8f4c0b0406c58c7c9d21bdc0b9c6e73a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 13 Jun 2025 14:24:17 +0200 Subject: [PATCH 15/16] Update. --- tinyae.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tinyae.py b/tinyae.py index b4f3aba..0baa5a2 100755 --- a/tinyae.py +++ b/tinyae.py @@ -124,20 +124,22 @@ test_input.sub_(mu).div_(std) ###################################################################### -for epoch in range(args.nb_epochs): - acc_loss = 0 +for n_epoch in range(args.nb_epochs): + acc_train_loss = 0 for input in train_input.split(args.batch_size): output = model(input) - loss = 0.5 * (output - input).pow(2).sum() / input.size(0) + train_loss = F.mse_loss(output, input) optimizer.zero_grad() - loss.backward() + train_loss.backward() optimizer.step() - acc_loss += loss.item() + acc_train_loss += train_loss.detach().item() * input.size(0) - log_string("acc_loss {:d} {:f}.".format(epoch, acc_loss)) + train_loss = acc_train_loss / train_input.size(0) + + log_string(f"train_loss {n_epoch} {train_loss}") ###################################################################### -- 2.39.5 From dbc13bfbcdb12f2becf82ec77042c4eeeb7bb4fc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 17 Jun 2025 17:14:20 +0200 Subject: [PATCH 16/16] Update. --- grid.py | 210 ++++++++++++++++++++++++++++++++++++++----------- picocrafter.py | 6 +- tinyae.py | 4 +- tinymnist.py | 22 +++++- 4 files changed, 188 insertions(+), 54 deletions(-) diff --git a/grid.py b/grid.py index d06802d..ac0ebe0 100755 --- a/grid.py +++ b/grid.py @@ -9,7 +9,7 @@ # This code implement a simple system to manipulate formal # specifications of tokens on a grid. -import math, re +import math, re, random import torch @@ -61,6 +61,7 @@ class FormalGrid: def constraint_to_fun(self, constraint): a, b, c = None, None, None + col, row = self.col, self.row def match(pattern): nonlocal a, b, c @@ -73,79 +74,70 @@ class FormalGrid: return False if match("([1-9]) is_in_top_half"): - return self.row[:, a] < self.grid_height // 2 + return row[:, a] < self.grid_height // 2 elif match("([1-9]) is_in_bottom_half"): - return self.row[:, a] >= self.grid_height // 2 + return row[:, a] >= self.grid_height // 2 elif match("([1-9]) is_on_left_side"): - return self.col[:, a] < self.grid_width // 2 + return col[:, a] < self.grid_width // 2 elif match("([1-9]) is_on_right_side"): - return self.col[:, a] >= self.grid_width // 2 + return col[:, a] >= self.grid_width // 2 elif match("([1-9]) next_to ([1-9])"): - return (self.row[:, a] - self.row[:, b]).abs() + ( - self.col[:, a] - self.col[:, b] - ).abs() <= 1 + return (row[:, a] - row[:, b]).abs() + (col[:, a] - col[:, b]).abs() == 1 elif match("([1-9]) is_below ([1-9])"): - return self.row[:, a] > self.row[:, b] + return row[:, a] > row[:, b] elif match("([1-9]) is_above ([1-9])"): - return self.row[:, a] < self.row[:, b] + return row[:, a] < row[:, b] elif match("([1-9]) is_left_of ([1-9])"): - return self.col[:, a] < self.col[:, b] + return col[:, a] < col[:, b] elif match("([1-9]) is_right_of ([1-9])"): - return self.col[:, a] > self.col[:, b] + return col[:, a] > col[:, b] elif match("([1-9]) ([1-9]) is_parallel_to_diagonal"): - return (self.col[:, a] - self.col[:, b]).abs() == ( - self.row[:, a] - self.row[:, b] - ).abs() + return (col[:, a] - col[:, b]).abs() == (row[:, a] - row[:, b]).abs() elif match("([1-9]) ([1-9]) is_vertical"): - return self.col[:, a] == self.col[:, b] + return col[:, a] == col[:, b] elif match("([1-9]) ([1-9]) is_horizontal"): - return self.row[:, a] == self.row[:, b] + return row[:, a] == row[:, b] elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"): - return (self.col[:, a] - self.col[:, b]) * ( - self.row[:, a] - self.row[:, c] - ) - (self.row[:, a] - self.row[:, b]) * ( - self.col[:, a] - self.col[:, c] - ) == 0 + return (col[:, a] - col[:, b]) * (row[:, a] - row[:, c]) - ( + row[:, a] - row[:, b] + ) * (col[:, a] - col[:, c]) == 0 elif match("([1-9]) middle_of ([1-9]) ([1-9])"): - return ( - grid_set - & (self.col[:, a] + self.col[:, c] == 2 * self.col[:, b]) - & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b]) + return (col[:, b] + col[:, a] == 2 * col[:, b]) & ( + row[:, b] + row[:, a] == 2 * row[:, b] ) elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"): - return (self.col[:, a] - self.col[:, b]) ** 2 + ( - self.row[:, a] - self.row[:, b] - ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + ( - self.row[:, a] - self.row[:, c] - ) ** 2 - - elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"): - return (self.col[:, a] - self.col[:, b]) ** 2 + ( - self.row[:, a] - self.row[:, b] - ) ** 2 > (self.col[:, a] - self.col[:, c]) ** 2 + ( - self.row[:, a] - self.row[:, c] - ) ** 2 + return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 == ( + col[:, a] - col[:, c] + ) ** 2 + (row[:, a] - row[:, c]) ** 2 + + elif match("([1-9]) is_further_from ([1-9]) than_from ([1-9])"): + return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 > ( + col[:, a] - col[:, c] + ) ** 2 + (row[:, a] - row[:, c]) ** 2 + + elif match("([1-9]) is_closer_to ([1-9]) than_to ([1-9])"): + return (col[:, a] - col[:, b]) ** 2 + (row[:, a] - row[:, b]) ** 2 < ( + col[:, a] - col[:, c] + ) ** 2 + (row[:, a] - row[:, c]) ** 2 elif match("([1-9]) ([1-9]) ([1-9]) form_a_right_angle"): - return (self.col[:, a] - self.col[:, b]) * ( - self.col[:, c] - self.col[:, b] - ) + (self.row[:, a] - self.row[:, b]) * ( - self.row[:, c] - self.row[:, b] - ) == 0 + return (col[:, a] - col[:, b]) * (col[:, c] - col[:, b]) + ( + row[:, a] - row[:, b] + ) * (row[:, c] - row[:, b]) == 0 else: raise ValueError(f"Unknown type of constraint {constraint}") @@ -184,13 +176,138 @@ class FormalGrid: v += " ".join(["-" if n == 0 else str(n.item()) for n in r]) + "\n" yield v + def random_property(self): + a, b, c = random.sample(list(range(1, self.nb_symbols + 1)), 3) + + sb, sc = min(b, c), max(b, c) + + ta, tb, tc = sorted((a, b, c)) + + l = ( + [ + f"{a} is_in_top_half", + f"{a} is_in_bottom_half", + f"{a} is_on_left_side", + f"{a} is_on_right_side", + ] + + [ + f"{a} is_below {b}", + f"{a} is_above {b}", + f"{a} is_left_of {b}", + f"{a} is_right_of {b}", + f"{sb} next_to {sc}", + ] + + [ + f"{sb} {sc} is_parallel_to_diagonal", + f"{sb} {sc} is_vertical", + f"{sb} {sc} is_horizontal", + ] + * 2 + + [ + f"{ta} {tb} {tc} are_aligned", + f"{a} middle_of {sb} {sc}", + f"{ta} {tb} {tc} form_a_right_angle", + ] + * 3 + + [ + f"{a} is_equidistant_from {sb} and {sc}", + f"{a} is_further_from {b} than_from {c}", + f"{a} is_closer_to {b} than_to {c}", + ] + ) + + return random.choice(l) + ###################################################################### if __name__ == "__main__": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device) + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + # grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device) + # grid_set = grid.new_grid_set(["4 is_equidistant_from 2 and 3", "2 4 is_parallel_to_diagonal"]) + # print(next(iter(grid.views(grid_set)))) + # exit(0) + + def proof_depth(steps, c): + a = steps.get(c) + if a is None: + return 0 + else: + c1, c2 = a + return max(proof_depth(steps, c1), proof_depth(steps, c2)) + + def generate_proof(grid): + while True: + constraints = [grid.random_property() for _ in range(10)] + grid_set = grid.new_grid_set(constraints) + if grid_set.any(): + break + + mg = grid.master_grid_set + + print(constraints) + + initial = constraints.copy() + + steps = {} + + for _ in range(1000): + c1, c2 = random.sample(constraints, 2) + f1, f2 = grid.constraint_to_fun(c1), grid.constraint_to_fun(c2) + for _ in range(100): + c = grid.random_property() + if c not in constraints: + f = grid.constraint_to_fun(c) + if ( + (mg & f1 & ~f).any() + and (mg & f2 & ~f).any() + and (mg & f1 & f2 & f).any() + and not (mg & f1 & f2 & ~f).any() + ): + constraints.append(c) + print(c1, "and", c2, "=>", c) + steps[c] = (c1, c2) + # print(next(iter(grid.views(grid.new_grid_set([c1, c2]))))) + # print("we have", constraints) + # proof.append(c1 + " and " + c2 + " hence " + c) + break + + if steps.keys() and max([proof_depth(steps, c) for c in steps.keys()]) >= 3: + + break + + return initial, steps + + grid = FormalGrid(grid_height=7, grid_width=7, nb_symbols=4, device=device) + + initial, steps = generate_proof(grid) + + print(" ; ".join(initial)) + + def proof(c, indent=""): + a = steps.get(c) + if a is None: + print(f"{indent}{c} is given") + else: + print(f"{indent}{c} since") + c1, c2 = a + proof(c1, indent + " ") + proof(c2, indent + " ") + + print(" ; ".join(initial)) + + for c in steps.keys(): + proof(c) + print() + + exit(0) # grid_set = grid.new_grid_set( # [ @@ -208,6 +325,9 @@ if __name__ == "__main__": "2 3 is_parallel_to_diagonal", "4 1 is_vertical", "3 4 is_horizontal", + "3 is_left_of 2", + "1 is_below 4", + "2 is_right_of 4", ], ) diff --git a/picocrafter.py b/picocrafter.py index 23d93b2..001bb81 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -227,9 +227,9 @@ class PicroCrafterEnvironment: r = u.sort(dim=-1, descending=True).indices[:, : len(z)] q *= self.tile2id["#"] - q[ - torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r - ] = torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :] + q[torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r] = ( + torch.tensor([self.tile2id[c] for c in z], device=q.device)[None, :] + ) if world_margin > 0: r = m.new_full( diff --git a/tinyae.py b/tinyae.py index 0baa5a2..806559e 100755 --- a/tinyae.py +++ b/tinyae.py @@ -92,8 +92,8 @@ class AutoEncoder(nn.Module): return self.decoder(z.view(z.size(0), -1, 1, 1)) def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) + x = self.encode(x) + x = self.decode(x) return x diff --git a/tinymnist.py b/tinymnist.py index f662be6..19b9387 100755 --- a/tinymnist.py +++ b/tinymnist.py @@ -1,5 +1,8 @@ #!/usr/bin/env python +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_PRE: source ~/venv/pytorch/bin/activate + # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ @@ -14,7 +17,12 @@ lr, nb_epochs, batch_size = 1e-1, 10, 100 data_dir = os.environ.get("PYTORCH_DATA_DIR") or "./data/" -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.cuda.is_available(): + device = torch.device("cuda") +elif torch.backends.mps.is_available(): + device = torch.device("mps") +else: + device = torch.device("cpu") ###################################################################### @@ -52,7 +60,7 @@ model = SomeLeNet() nb_parameters = sum(p.numel() for p in model.parameters()) -print(f"nb_parameters {nb_parameters}") +print(f"device {device} nb_parameters {nb_parameters}") optimizer = torch.optim.SGD(model.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() @@ -83,17 +91,23 @@ for k in range(nb_epochs): loss.backward() optimizer.step() + acc_test_loss = 0.0 nb_test_errors = 0 for input, targets in zip( test_input.split(batch_size), test_targets.split(batch_size) ): - wta = model(input).argmax(1) + output = model(input) + loss = criterion(output, targets) + acc_test_loss += loss.item() * input.size(0) + + wta = output.argmax(1) nb_test_errors += (wta != targets).long().sum() + test_error = nb_test_errors / test_input.size(0) duration = time.perf_counter() - start_time print( - f"loss {k} {duration:.02f}s {acc_train_loss/train_input.size(0):.02f} {test_error*100:.02f}%" + f"loss {k} {duration:.02f}s acc_train_loss {acc_train_loss/train_input.size(0):.02f} test_loss {acc_test_loss/test_input.size(0):.02f} test_error {test_error*100:.02f}%" ) ###################################################################### -- 2.39.5