From 1692b231e06fbd05cef228e5b492cb973c7a047d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 9 Oct 2024 22:54:00 +0200 Subject: [PATCH] Update. --- attae.py | 41 ++++++++++++++++++++++++++++++++++------- main.py | 35 +++++++---------------------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/attae.py b/attae.py index 5d9da2e..3eb6c4e 100755 --- a/attae.py +++ b/attae.py @@ -3,7 +3,7 @@ # Any copyright is dedicated to the Public Domain. # https://creativecommons.org/publicdomain/zero/1.0/ -import math +import math, warnings import torch @@ -33,6 +33,27 @@ class VaswaniPositionalEncoding(nn.Module): ###################################################################### +class BlockRandomPositionalEncoding(nn.Module): + def __init__(self, dim, block_size, nb_blocks): + super().__init__() + self.pe_inside = nn.Parameter(torch.randn(1, block_size, dim) / math.sqrt(dim)) + self.pe_per_blocks = nn.Parameter( + torch.randn(1, nb_blocks, dim) / math.sqrt(dim) + ) + + def forward(self, x): + pe = self.pe_inside.repeat( + x.size(0), self.pe_per_blocks.size(1), 1 + ) + self.pe_per_blocks.repeat_interleave(self.pe_inside.size(1), dim=1).repeat( + x.size(0), 1, 1 + ) + y = x + pe + return y + + +###################################################################### + + class WithResidual(nn.Module): def __init__(self, *f): super().__init__() @@ -171,9 +192,15 @@ class AttentionAE(nn.Module): def forward(self, x): x = self.embedding(x) + + warnings.warn("flipping order for symmetry check", RuntimeWarning) + x = torch.cat([x[:, 200:], x[:, :200]], dim=1) x = self.positional_encoding(x) + x = torch.cat([x[:, 200:], x[:, :200]], dim=1) + x = self.trunk(x) x = self.readout(x) + return x @@ -330,6 +357,9 @@ class Reasoning(nn.Module): ) def forward(self, x_q): + #!!!!!!!!!!!!!!!!!!!! + # x_q = torch.cat([x_q[:,200:,:], x_q[:,:200,:]],dim=1) + T, S = x_q.size(1), self.x_star.size(0) nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks @@ -339,12 +369,6 @@ class Reasoning(nn.Module): x = self.trunk_joint(x) f, x = x[:, :S, :], x[:, S:, :] - - if hasattr(self, "forced_f") and self.forced_f is not None: - f = self.forced_f.clone() - - self.pred_f = f.clone() - x = x.reshape(nb * nc, T // nc, dim) f = f.repeat(nc, 1, 1) x = torch.cat([f, x], dim=1) @@ -353,6 +377,9 @@ class Reasoning(nn.Module): x = x[:, S:, :] x = x.reshape(nb, T, dim) + #!!!!!!!!!!!!!!!!!!!! + # x = torch.cat([x[:,200:,:], x[:,:200,:]],dim=1) + return x diff --git a/main.py b/main.py index 618a62e..ba5c6e2 100755 --- a/main.py +++ b/main.py @@ -558,13 +558,6 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): ) loss = (loss_per_token * masks).mean() - if args.test == "aebn": - error = 0 - for m in model.modules(): - if hasattr(m, "error"): - error = error + m.error - loss = loss + error - acc_loss += loss.item() * imt.size(0) nb_samples += imt.size(0) @@ -574,14 +567,6 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): if nb_samples % args.batch_size == 0: model.optimizer.step() - if args.test == "aebn": - nb_me = [] - for m in model.modules(): - if hasattr(m, "nb_me"): - nb_me.append(m.nb_me.item()) - - log_string(f"{label}_error {n_epoch} model {model.id} {error} nb_me {nb_me}") - log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}") @@ -613,6 +598,7 @@ def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_de # Save some images of the ex nihilo generation of the four grids result = ae_generate(model, 150, local_device=local_device).to("cpu") + problem.save_quizzes_as_image( args.result_dir, f"culture_generation_{n_epoch}_{model.id}.png", @@ -941,10 +927,9 @@ if args.test == "aebn": len_max=1e4, ) - model.id = 0 - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - model.test_accuracy = 0.0 - model.nb_epochs = 0 + model.positional_encoding = attae.BlockRandomPositionalEncoding( + args.dim_model, 100, 4 + ) model.trunk = attae.Reasoning( nb_f_tokens=25, @@ -957,16 +942,10 @@ if args.test == "aebn": attention_dropout=args.dropout, ) - # model.trunk = model.trunk[: len(model.trunk) // 2] + nn.Sequential( - # attae.LearningToBeMe(f, g, 1e-1) - # ) - - # model.id = 0 - # model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - # model.test_accuracy = 0.0 - # model.nb_epochs = 0 - + model.id = 0 model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.test_accuracy = 0.0 + model.nb_epochs = 0 for n_epoch in range(args.nb_epochs): one_complete_epoch( -- 2.39.5