From 80debccd24ca12e620382185ed44266041064c64 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 29 Aug 2024 08:07:24 +0200
Subject: [PATCH] Update.

---
 main.py | 242 ++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 193 insertions(+), 49 deletions(-)

diff --git a/main.py b/main.py
index 85213cb..43a8774 100755
--- a/main.py
+++ b/main.py
@@ -728,6 +728,33 @@ class MultiEmbedding(nn.Module):
         return y
 
 
+def attention_block(dim_model, dim_keys, nb_heads, dropout):
+    return WithResidual(
+        CacheWrapper(
+            nn.LayerNorm((dim_model,)),
+        ),
+        QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention_dropout=dropout,
+        ),
+    )
+
+
+def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
+    return WithResidual(
+        CacheWrapper(
+            nn.LayerNorm((dim_model,)),
+            nn.Linear(in_features=dim_model, out_features=dim_hidden),
+            nn.ReLU(),
+            nn.Linear(in_features=dim_hidden, out_features=dim_model),
+            nn.Dropout(dropout),
+        ),
+    )
+
+
 class MyAttentionAE(nn.Module):
     def __init__(
         self,
@@ -756,40 +783,9 @@ class MyAttentionAE(nn.Module):
         trunk_blocks = []
 
         for b in range(nb_blocks):
-            # if b == nb_blocks//2:
-            # trunk_blocks += [
-            # QKVAttention(
-            # dim_in=dim_model,
-            # dim_qk=dim_keys,
-            # dim_v=dim_model // nb_heads,
-            # nb_heads=nb_heads,
-            # attention_dropout=dropout,
-            # ),
-            # VaswaniPositionalEncoding(len_max=1e5)
-            # ]
             trunk_blocks += [
-                WithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                    ),
-                    QKVAttention(
-                        dim_in=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                WithResidual(
-                    CacheWrapper(
-                        nn.LayerNorm((dim_model,)),
-                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                        nn.ReLU(),
-                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                        nn.Dropout(dropout),
-                    ),
-                ),
+                attention_block(dim_model, dim_keys, nb_heads, dropout),
+                ffw_block(dim_model, dim_hidden, nb_heads, dropout),
             ]
 
         self.trunk = nn.Sequential(*trunk_blocks)
@@ -816,6 +812,135 @@ class MyAttentionAE(nn.Module):
         return bs
 
 
+######################################################################
+
+# f = phi(A, f(A)) + phi(B, f(B))
+# \hat{f(A)} = psi(A, f)
+# \hat{A} = psi_inv(f(A), f)
+# \hat{f(B)} = psi(B, f)
+# \hat{B} = psi_inv(f(B), f)
+
+
+def attention_layer(dim_model, dim_keys, nb_heads, dropout):
+    return WithResidual(
+        CacheWrapper(
+            nn.LayerNorm((dim_model,)),
+        ),
+        QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention_dropout=dropout,
+        ),
+    )
+
+
+class FunctionalAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1024,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = CacheWrapper(
+            nn.Sequential(
+                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+            ),
+        )
+
+        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+
+        def trunk(nb, bottom=True):
+            trunk_blocks = []
+
+            la = [
+                QKVAttention(
+                    dim_in=dim_model,
+                    dim_qk=dim_keys,
+                    dim_v=dim_model // nb_heads,
+                    nb_heads=nb_heads,
+                    attention_dropout=dropout,
+                ),
+                VaswaniPositionalEncoding(len_max=1e5),
+            ]
+
+            # if not bottom:
+            # trunk_blocks += la
+
+            for b in range(nb):
+                trunk_blocks += [
+                    attention_block(dim_model, dim_keys, nb_heads, dropout),
+                    ffw_block(dim_model, dim_hidden, nb_heads, dropout),
+                ]
+
+            # if bottom:
+            # trunk_blocks += la
+
+            return nn.Sequential(*trunk_blocks)
+
+        self.phi = trunk(nb_blocks // 2, bottom=True)
+        nb_f_tokens = 200
+        self.f_tokens = nn.Parameter(
+            torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens)
+        )
+        self.psi = trunk(nb_blocks // 2, bottom=False)
+        self.psi_inv = trunk(nb_blocks // 2, bottom=False)
+        self.internal_pe = VaswaniPositionalEncoding(len_max=1e5)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, bs):
+        def cat(*x):
+            return BracketedSequence(torch.cat(x, dim=1))
+
+        if torch.is_tensor(bs):
+            return self.forward(BracketedSequence(bs)).x
+        bs = self.embedding(bs)
+        bs = self.positional_encoding(bs)
+
+        x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1)
+
+        K = self.f_tokens.size(1)
+        N, L = x_A.size()[:2]
+
+        ft = self.f_tokens.expand(N, -1, -1)
+
+        theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
+        theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
+
+        hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
+        hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
+
+        hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L]
+        hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L]
+
+        bs = cat(hat_A, hat_f_A, hat_B, hat_f_B)
+
+        bs = self.readout(bs)
+        return bs
+
+
 ######################################################################
 
 nb_iterations = 25
@@ -926,19 +1051,25 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
 
 
 def model_ae_proba_solutions(model, input):
-    loss = 0
+    record = []
 
-    for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-        mask_generate = quiz_machine.make_quiz_mask(
-            quizzes=input, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-        )
-        targets, logits = targets_and_prediction(
-            probs_iterations, model, input, mask_generate
-        )
-        loss_per_token = F.cross_entropy(
-            logits.transpose(1, 2), targets, reduction="none"
-        )
-        loss += (loss_per_token * mask_generate).sum(dim=1)
+    for q in input.split(args.batch_size):
+        loss = 0
+
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            targets, logits = targets_and_prediction(
+                probs_iterations, model, q, mask_generate
+            )
+            loss_per_token = F.cross_entropy(
+                logits.transpose(1, 2), targets, reduction="none"
+            )
+            loss += (loss_per_token * mask_generate).sum(dim=1)
+        record.append(loss)
+
+    loss = torch.cat(record, dim=0)
 
     return (-loss).exp()
 
@@ -1108,7 +1239,8 @@ noise_proba = 0.05
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
+    # model = MyAttentionAE(
+    model = FunctionalAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1169,7 +1301,9 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-for n_epoch in range(args.nb_epochs):
+for n_epoch in range(current_epoch, args.nb_epochs):
+    start_time = time.perf_counter()
+
     state = {
         "current_epoch": n_epoch,
         # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
@@ -1187,8 +1321,8 @@ for n_epoch in range(args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
-    exit(0)
+    # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
+    # exit(0)
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
@@ -1231,3 +1365,13 @@ for n_epoch in range(args.nb_epochs):
         )
 
         log_string(f"wrote {filename}")
+
+    # --------------------------------------------------------------------
+
+    duration = time.perf_counter() - start_time
+    str_duration = ""
+    if duration >= 60:
+        str_duration += f"{int(duration//60)}min"
+        duration = duration % 60
+    str_duration += f"{duration:.01f}s"
+    log_string(f"epoch_duration {str_duration}")
-- 
2.39.5
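
Note (not part of the patch, placed after the signature delimiter so it is ignored by git am): the comment block above FunctionalAE states the objective only in shorthand (f = phi(A, f(A)) + phi(B, f(B)), hat{f(A)} = psi(A, f), hat{A} = psi_inv(f(A), f)). The sketch below is a minimal, self-contained illustration of that cross-pair objective under stated assumptions: plain MLPs stand in for the transformer trunks phi, psi and psi_inv of the patch, tokens are flattened embeddings rather than BracketedSequence objects, and all names and dimensions (vocab, seq_len, dim, cross_pair_loss, ...) are hypothetical and chosen only for the example.

import torch
from torch import nn
from torch.nn import functional as F

vocab, seq_len, dim = 16, 10, 64

embed = nn.Embedding(vocab, dim)

# phi maps a flattened (x, f(x)) pair to a function embedding theta
phi = nn.Sequential(nn.Linear(2 * seq_len * dim, dim), nn.ReLU(), nn.Linear(dim, dim))
# psi predicts f(x) from (x, theta); psi_inv predicts x from (f(x), theta)
psi = nn.Sequential(
    nn.Linear(seq_len * dim + dim, dim), nn.ReLU(), nn.Linear(dim, seq_len * vocab)
)
psi_inv = nn.Sequential(
    nn.Linear(seq_len * dim + dim, dim), nn.ReLU(), nn.Linear(dim, seq_len * vocab)
)


def flat(t):
    # embed a token sequence (N, L) and flatten it to (N, L * dim)
    return embed(t).flatten(1)


def cross_pair_loss(A, f_A, B, f_B):
    theta_A = phi(torch.cat([flat(A), flat(f_A)], dim=1))
    theta_B = phi(torch.cat([flat(B), flat(f_B)], dim=1))

    def logits(decoder, x, theta):
        # decode a sequence of per-position logits from (x, theta)
        return decoder(torch.cat([flat(x), theta], dim=1)).view(-1, seq_len, vocab)

    # the embedding estimated on one pair must explain the *other* pair,
    # mirroring hat_f_A = psi(cat(x_A, theta_B)) etc. in FunctionalAE.forward
    hat_f_A = logits(psi, A, theta_B)
    hat_f_B = logits(psi, B, theta_A)
    hat_A = logits(psi_inv, f_A, theta_B)
    hat_B = logits(psi_inv, f_B, theta_A)

    return (
        F.cross_entropy(hat_f_A.transpose(1, 2), f_A)
        + F.cross_entropy(hat_f_B.transpose(1, 2), f_B)
        + F.cross_entropy(hat_A.transpose(1, 2), A)
        + F.cross_entropy(hat_B.transpose(1, 2), B)
    )


A, f_A, B, f_B = (torch.randint(vocab, (8, seq_len)) for _ in range(4))
print(cross_pair_loss(A, f_A, B, f_B))

The point being illustrated is that the function embedding theta computed from one (input, output) pair is only ever used to reconstruct the other pair, which pushes theta to encode the mapping itself rather than the particular examples it was computed from.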