return y
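+# Residual block applying a LayerNorm on the residual branch, followed by
+# multi-head QKV self-attention (pre-norm layout).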
+def attention_block(dim_model, dim_keys, nb_heads, dropout):
+ return WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ ),
+ QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention_dropout=dropout,
+ ),
+ )
+
+
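+# Residual block applying a LayerNorm, then a two-layer feed-forward network
+# (dim_model -> dim_hidden -> dim_model) with ReLU and dropout. nb_heads is
+# accepted only to mirror attention_block's signature and is not used here.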
+def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
+ return WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ )
+
+
class MyAttentionAE(nn.Module):
def __init__(
self,
trunk_blocks = []
for b in range(nb_blocks):
- # if b == nb_blocks//2:
- # trunk_blocks += [
- # QKVAttention(
- # dim_in=dim_model,
- # dim_qk=dim_keys,
- # dim_v=dim_model // nb_heads,
- # nb_heads=nb_heads,
- # attention_dropout=dropout,
- # ),
- # VaswaniPositionalEncoding(len_max=1e5)
- # ]
-
trunk_blocks += [
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention_dropout=dropout,
- ),
- ),
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ),
+ attention_block(dim_model, dim_keys, nb_heads, dropout),
+ ffw_block(dim_model, dim_hidden, nb_heads, dropout),
]
self.trunk = nn.Sequential(*trunk_blocks)
return bs
+######################################################################
+
+# f = phi(A, f(A)) + phi(B, f(B))
+# \hat{f(A)} = psi(A, f)
+# \hat{A} = psi_inv(f(A), f)
+# \hat{f(B)} = psi(B, f)
+# \hat{B} = psi_inv(f(B), f)
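+#
+# In the implementation below, theta_A = phi(f_tokens, A, f(A)) and
+# theta_B = phi(f_tokens, B, f(B)) play the role of f, and they are used
+# cross-wise: the pair (A, f(A)) is reconstructed from theta_B, and the
+# pair (B, f(B)) from theta_A.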
+
+
+class FunctionalAE(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ len_max=1024,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = CacheWrapper(
+ nn.Sequential(
+ MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+ ),
+ )
+
+ # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+ self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+
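+ # Builds a stack of nb (attention + feed-forward) blocks. The `bottom`
+ # flag and the `la` list below are only referenced by the commented-out
+ # variants that would add a plain attention / positional-encoding stage.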
+ def trunk(nb, bottom=True):
+ trunk_blocks = []
+
+ la = [
+ QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention_dropout=dropout,
+ ),
+ VaswaniPositionalEncoding(len_max=1e5),
+ ]
+
+ # if not bottom:
+ # trunk_blocks += la
+
+ for b in range(nb):
+ trunk_blocks += [
+ attention_block(dim_model, dim_keys, nb_heads, dropout),
+ ffw_block(dim_model, dim_hidden, nb_heads, dropout),
+ ]
+
+ # if bottom:
+ # trunk_blocks += la
+
+ return nn.Sequential(*trunk_blocks)
+
+ self.phi = trunk(nb_blocks // 2, bottom=True)
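+ # Learned tokens prepended to each (x, f(x)) pair; their output
+ # embeddings after phi serve as the function representation theta.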
+ nb_f_tokens = 200
+ self.f_tokens = nn.Parameter(
+ torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens)
+ )
+ self.psi = trunk(nb_blocks // 2, bottom=False)
+ self.psi_inv = trunk(nb_blocks // 2, bottom=False)
+ self.internal_pe = VaswaniPositionalEncoding(len_max=1e5)
+
+ self.readout = CacheWrapper(
+ nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+ )
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, bs):
+ def cat(*x):
+ return BracketedSequence(torch.cat(x, dim=1))
+
+ if torch.is_tensor(bs):
+ return self.forward(BracketedSequence(bs)).x
+
+ bs = self.embedding(bs)
+ bs = self.positional_encoding(bs)
+
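+ # Quizzes are laid out as the concatenation A | f(A) | B | f(B) along
+ # the sequence dimension, so chunking in four recovers the parts.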
+ x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1)
+
+ K = self.f_tokens.size(1)
+ N, L = x_A.size()[:2]
+
+ ft = self.f_tokens.expand(N, -1, -1)
+
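+ # Keep only the first K output positions, i.e. the transformed f_tokens,
+ # as the representation of the underlying transformation.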
+ theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
+ theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
+
+ hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
+ hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
+
+ hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L]
+ hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L]
+
+ bs = cat(hat_A, hat_f_A, hat_B, hat_f_B)
+
+ bs = self.readout(bs)
+ return bs
+
+
######################################################################
nb_iterations = 25
def model_ae_proba_solutions(model, input):
- loss = 0
+ record = []
- for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
- mask_generate = quiz_machine.make_quiz_mask(
- quizzes=input, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
- )
- targets, logits = targets_and_prediction(
- probs_iterations, model, input, mask_generate
- )
- loss_per_token = F.cross_entropy(
- logits.transpose(1, 2), targets, reduction="none"
- )
- loss += (loss_per_token * mask_generate).sum(dim=1)
+ for q in input.split(args.batch_size):
+ loss = 0
+
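+ # Accumulate, over the four possible masked quadrants, the cross-entropy
+ # restricted to the tokens to generate, giving one loss value per quiz.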
+ for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+ mask_generate = quiz_machine.make_quiz_mask(
+ quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+ )
+ targets, logits = targets_and_prediction(
+ probs_iterations, model, q, mask_generate
+ )
+ loss_per_token = F.cross_entropy(
+ logits.transpose(1, 2), targets, reduction="none"
+ )
+ loss += (loss_per_token * mask_generate).sum(dim=1)
+ record.append(loss)
+
+ loss = torch.cat(record, dim=0)
return (-loss).exp()
models = []
for i in range(args.nb_models):
- model = MyAttentionAE(
+ # model = MyAttentionAE(
+ model = FunctionalAE(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
######################################################################
-for n_epoch in range(args.nb_epochs):
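+# Resume at current_epoch (presumably restored from a saved checkpoint)
+# rather than always starting at 0.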
+for n_epoch in range(current_epoch, args.nb_epochs):
+ start_time = time.perf_counter()
+
state = {
"current_epoch": n_epoch,
# "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
# --------------------------------------------------------------------
- one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
- exit(0)
+ # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
+ # exit(0)
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
)
log_string(f"wrote {filename}")
+
+ # --------------------------------------------------------------------
+
+ duration = time.perf_counter() - start_time
+ str_duration = ""
+ if duration >= 60:
+ str_duration += f"{int(duration//60)}min"
+ duration = duration % 60
+ str_duration += f"{duration:.01f}s"
+ log_string(f"epoch_duration {str_duration}")