# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/
-import math
+import math, warnings
import torch
from torch import nn
######################################################################
+class BlockRandomPositionalEncoding(nn.Module):
+    def __init__(self, dim, block_size, nb_blocks):
+        super().__init__()
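+        # two learned code tables: one entry per position inside a block,
+        # and one entry per block, both scaled by 1/sqrt(dim)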
+        self.pe_inside = nn.Parameter(torch.randn(1, block_size, dim) / math.sqrt(dim))
+        self.pe_per_blocks = nn.Parameter(
+            torch.randn(1, nb_blocks, dim) / math.sqrt(dim)
+        )
+
+    def forward(self, x):
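+        # tile the inside-block codes across the nb_blocks blocks and repeat
+        # each block code block_size times, so position i receives
+        # pe_inside[i % block_size] + pe_per_blocks[i // block_size];
+        # assumes x.size(1) == block_size * nb_blocks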
+        pe = self.pe_inside.repeat(
+            x.size(0), self.pe_per_blocks.size(1), 1
+        ) + self.pe_per_blocks.repeat_interleave(self.pe_inside.size(1), dim=1).repeat(
+            x.size(0), 1, 1
+        )
+        y = x + pe
+        return y
+
+
+######################################################################
+
+
class WithResidual(nn.Module):
    def __init__(self, *f):
        super().__init__()

######################################################################

    def forward(self, x):
        x = self.embedding(x)
+
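+        # symmetry check: swap the two 200-token halves before the positional
+        # encoding and swap them back afterwards, so token i receives the
+        # positional code of position (i + 200) mod 400 (assumes sequences
+        # of length 400)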
+        warnings.warn("flipping order for symmetry check", RuntimeWarning)
+        x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
        x = self.positional_encoding(x)
+        x = torch.cat([x[:, 200:], x[:, :200]], dim=1)
+
        x = self.trunk(x)
        x = self.readout(x)
+
        return x
######################################################################

        )
    def forward(self, x_q):
+        #!!!!!!!!!!!!!!!!!!!!
+        # x_q = torch.cat([x_q[:, 200:, :], x_q[:, :200, :]], dim=1)
+
        T, S = x_q.size(1), self.x_star.size(0)
        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
        x = torch.cat([self.x_star.expand(nb, S, dim), x_q], dim=1)  # reconstructed; exact form assumed
        x = self.trunk_joint(x)
        f, x = x[:, :S, :], x[:, S:, :]
-
-        if hasattr(self, "forced_f") and self.forced_f is not None:
-            f = self.forced_f.clone()
-
-        self.pred_f = f.clone()
-
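+        # chunk the sequence into nc pieces, prepend the same S f-tokens to
+        # every chunk, and strip them off again after the per-chunk pass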
        x = x.reshape(nb * nc, T // nc, dim)
        f = f.repeat(nc, 1, 1)
        x = torch.cat([f, x], dim=1)
        x = self.trunk_marginal(x)  # reconstructed; name assumed to mirror trunk_joint
        x = x[:, S:, :]
        x = x.reshape(nb, T, dim)
+        #!!!!!!!!!!!!!!!!!!!!
+        # x = torch.cat([x[:, 200:, :], x[:, :200, :]], dim=1)
+
        return x
######################################################################

        )
        loss = (loss_per_token * masks).mean()
-        if args.test == "aebn":
-            error = 0
-            for m in model.modules():
-                if hasattr(m, "error"):
-                    error = error + m.error
-            loss = loss + error
-
        acc_loss += loss.item() * imt.size(0)
        nb_samples += imt.size(0)
        if nb_samples % args.batch_size == 0:
            model.optimizer.step()
-    if args.test == "aebn":
-        nb_me = []
-        for m in model.modules():
-            if hasattr(m, "nb_me"):
-                nb_me.append(m.nb_me.item())
-
-        log_string(f"{label}_error {n_epoch} model {model.id} {error} nb_me {nb_me}")
-
    log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
    # Save some images of the ex nihilo generation of the four grids
    result = ae_generate(model, 150, local_device=local_device).to("cpu")
+
    problem.save_quizzes_as_image(
        args.result_dir,
        f"culture_generation_{n_epoch}_{model.id}.png",
        quizzes=result,  # keyword assumed
        len_max=1e4,
    )
-    model.id = 0
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    model.test_accuracy = 0.0
-    model.nb_epochs = 0
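+    # 4 blocks of 100 tokens, i.e. 400 positions in total, so the 200-token
+    # half-swap in the encoder forward exchanges exactly two blocks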
+    model.positional_encoding = attae.BlockRandomPositionalEncoding(
+        args.dim_model, 100, 4
+    )
    model.trunk = attae.Reasoning(
        nb_f_tokens=25,
        attention_dropout=args.dropout,
    )
-    # model.trunk = model.trunk[: len(model.trunk) // 2] + nn.Sequential(
-    #     attae.LearningToBeMe(f, g, 1e-1)
-    # )
-
-    # model.id = 0
-    # model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    # model.test_accuracy = 0.0
-    # model.nb_epochs = 0
-
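+    # re-create the optimizer (and reset the bookkeeping) only after the
+    # modules above have been swapped in, so Adam tracks the new parameters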
+    model.id = 0
    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    model.test_accuracy = 0.0
+    model.nb_epochs = 0
    for n_epoch in range(args.nb_epochs):
        one_complete_epoch(