From 9fa8d9a54a3802dd8a18d32a29e4649be191d6e4 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 13 Oct 2024 22:12:03 +0200
Subject: [PATCH] Update.

---
 attae.py | 210 +++++++++++++++++++++++++++++++++----------------------
 main.py  |  59 +++++++++++++++-
 2 files changed, 181 insertions(+), 88 deletions(-)

diff --git a/attae.py b/attae.py
index 94da984..d8c68cc 100755
--- a/attae.py
+++ b/attae.py
@@ -140,6 +140,86 @@ class MHAttention(nn.Module):
 ######################################################################
 
 
+class ModulatedMHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention=vanilla_attention,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        self.dim_qk = dim_qk
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention = attention
+        self.attention_dropout = attention_dropout
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
+
+    def forward(self, x_q, x_kv=None):
+        modulation, x_q = x_q[:, :, : self.dim_qk], x_q[:, :, self.dim_qk :]
+        if x_kv is None:
+            x_kv = x_q
+
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        q = q * modulation.sigmoid()
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+        y = self.attention(q, k, v)
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return torch.cat([modulation, y], dim=2)
+
+
+######################################################################
+
+
+class AttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.ln1 = nn.LayerNorm((dim_model,))
+        self.mha = MHAttention(
+            dim_model=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            attention=vanilla_attention,
+            attention_dropout=dropout,
+        )
+        self.ln2 = nn.LayerNorm((dim_model,))
+        self.fc1 = nn.Linear(in_features=dim_model, out_features=dim_hidden)
+        self.fc2 = nn.Linear(in_features=dim_hidden, out_features=dim_model)
+        self.drop_out = nn.Dropout(dropout)
+
+    def forward(self, x):
+        y = self.ln1(x)
+        y = self.mha(y)
+        x = x + y
+        y = self.ln2(x)
+        y = self.fc1(y)
+        y = F.relu(y)
+        y = self.fc2(y)
+        y = self.drop_out(y)
+        x = x + y
+        return x
+
+
 def create_trunk(
     dim_model,
     dim_keys,
@@ -149,35 +229,12 @@ def create_trunk(
     dropout=0.0,
     residual_masker=None,
 ):
-    trunk_blocks = []
-
-    for b in range(nb_blocks):
-        trunk_blocks += [
-            WithResidual(
-                masker=residual_masker,
-                f=(
-                    nn.LayerNorm((dim_model,)),
-                    MHAttention(
-                        dim_model=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention=vanilla_attention,
-                        attention_dropout=dropout,
-                    ),
-                ),
-            ),
-            WithResidual(
-                masker=residual_masker,
-                f=(
-                    nn.LayerNorm((dim_model,)),
-                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                    nn.ReLU(),
-                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                    nn.Dropout(dropout),
-                ),
-            ),
-        ]
+    trunk_blocks = [
+        AttentionBlock(
+            dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0
+        )
+        for _ in range(nb_blocks)
+    ]
 
     return nn.Sequential(*trunk_blocks)
 
@@ -289,31 +346,7 @@ class FunctionalAttentionAE(nn.Module):
             m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
            return m[None, :, None]
 
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                WithMaskedResidual(
-                    masker,
-                    nn.LayerNorm((dim_model,)),
-                    MHAttention(
-                        dim_model=dim_model,
-                        dim_qk=dim_keys,
-                        dim_v=dim_model // nb_heads,
-                        nb_heads=nb_heads,
-                        attention=no_peek_attention,
-                        attention_dropout=dropout,
-                    ),
-                ),
-                WithMaskedResidual(
-                    masker,
-                    nn.LayerNorm((dim_model,)),
-                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                    nn.ReLU(),
-                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                    nn.Dropout(dropout),
-                ),
-            ]
-
-        self.trunk = nn.Sequential(*trunk_blocks)
+        self.trunk = nn.Sequential(*[AttentionBlock() for _ in range(nb_blocks)])
 
         self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
 
@@ -360,6 +393,8 @@ class Reasoning(nn.Module):
         self.nb_chunks = nb_chunks
 
         self.x_star = randw(nb_f_tokens, dim_model)
+        with torch.no_grad():
+            self.x_star *= 1e-3
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
@@ -392,52 +427,57 @@ class Reasoning(nn.Module):
             attention_dropout=attention_dropout,
         )
 
-        self.mha_B = MHAttention(
-            dim_model=dim_model,
-            dim_qk=dim_qk,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            attention=vanilla_attention,
-            attention_dropout=attention_dropout,
-        )
+    def forward_(self, x_q):
+        nb, T, dim = x_q.size()
+        nc, S = self.nb_chunks, self.x_star.size(0)
 
-    def forward_AB(self, x_q):
-        T, S = x_q.size(1), self.x_star.size(0)
-        nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
-
-        x = x_q
-        x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
-        x = self.trunk_A(x)
+        x = x_q.reshape(nb * nc, T // nc, dim)
         f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
-        f = self.mha_A(f, x)
+        # x = torch.cat([f, x], dim=1)
+        x = self.trunk_A(x)
 
-        k = torch.arange(nb, device=x_q.device)
-        u = f[k * 2, :]
-        f[k * 2, :] = f[k * 2 + 1, :]
-        f[k * 2 + 1, :] = u
+        x_star = self.x_star.reshape(1, S, dim)
+        f = self.mha_A(x_star, x).mean(dim=1, keepdim=True)
 
-        f = self.mha_B(x, f)
-        x = self.trunk_B(x)
-        x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
+        k = torch.arange(nb * nc, device=x_q.device)
+        k = k + 1 - 2 * (k % 2)
+        f = f[k]
+        # u = x[k * 2, :S]
+        # x[k * 2, :S] = x[k * 2 + 1, :S]
+        # x[k * 2 + 1, :S] = u
+        # x[:, S:] = x_q.reshape(nb * nc, T // nc, dim)
+
+        x = self.trunk_B(x, q_modulation=f)
+
+        x = x[:, S:].reshape(nb, T, dim)
 
         return x
 
     def forward(self, x_q):
         T, S = x_q.size(1), self.x_star.size(0)
         nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
-        f = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
 
         x = x_q
         x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
-        f = f.repeat(nc, 1, 1)
-        x = torch.cat([f, x], dim=1)
         x = self.trunk_A(x)
 
-        k = torch.arange(nb, device=x_q.device)
-        u = x[k * 2, :S]
-        x[k * 2, :S] = x[k * 2 + 1, :S]
-        x[k * 2 + 1, :S] = u
+        # f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
+        # f = self.mha_A(f, x)
+
+        # if hasattr(self, "f"):
+        #     if torch.is_tensor(self.f):
+        #         f, self.f = self.f, f
+        #     else:
+        #         self.f = f
+
+        # k = torch.arange(nb, device=x_q.device)
+        # u = f[k * 2, :]
+        # f[k * 2, :] = f[k * 2 + 1, :]
+        # f[k * 2 + 1, :] = u
+
+        # x = x_q
+        # x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+        # f = self.mha_B(x, f)
 
         x = self.trunk_B(x)
-        x = x[:, S:]
         x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
 
         return x
diff --git a/main.py b/main.py
index c8d6f10..85b916b 100755
--- a/main.py
+++ b/main.py
@@ -357,10 +357,11 @@ def add_input_noise_imt(imt_set, proba_input_noise):
 # Prediction
 
 
-def make_imt_samples_for_prediction(input):
+def make_imt_samples_for_prediction(input, u=None):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
-    u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+    if u is None:
+        u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
     masks.view(nb, 4, -1)[...] = u[:, :, None]
     targets = input
     input = (1 - masks) * targets
@@ -917,6 +918,56 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+
+def save_f_token_manipulations(model, n_epoch, local_device):
+    quizzes = generate_quiz_set(256, None, args.c_quiz_multiplier)
+
+    u = F.one_hot(torch.full((quizzes.size(0),), 3, device=local_device), num_classes=4)
+
+    imt_set = make_imt_samples_for_prediction(quizzes.to(local_device), u=u)
+
+    model.eval().to(local_device)
+
+    record = []
+
+    src = tqdm.tqdm(
+        imt_set.split(args.eval_batch_size),
+        dynamic_ncols=True,
+        desc="predict",
+        total=imt_set.size(0) // args.eval_batch_size,
+        delay=10,
+    )
+
+    N = args.eval_batch_size
+
+    for imt in src:
+        # some paranoia
+        imt = imt.clone()
+        imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+        model.trunk.f = True
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            batch = imt[:, 0] + imt[:, 1] * vocabulary_size
+            logits = model(batch)
+            x = batch[N // 2 : N // 2 + 1].clone()
+            batch[N // 2 :] = batch[: N // 2]
+            batch[: N // 2] = x.expand(N // 2, -1)
+            logits = model(batch)
+
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        result = dist.sample()
+        record.append(result)
+
+    result = torch.cat(record).to("cpu")
+
+    problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_f_token_manipulation.png",
+        quizzes=result[:128],
+        nrow=N // 2,
+    )
+
+
 if args.test == "aebn":
     model = attae.AttentionAE(
         vocabulary_size_in=vocabulary_size * 2,
@@ -944,7 +995,7 @@ if args.test == "aebn":
         pe,
         # trainable=True
     )
-    nb_f_tokens = 200
+    nb_f_tokens = 8
 
     def no_f_residual(x):
         m = x.new_full((1, x.size(1), 1), 1.0)
@@ -981,6 +1032,8 @@ if args.test == "aebn":
         model.nb_epochs = d["nb_epochs"]
         log_string(f"successfully loaded {filename} nb_epochs {model.nb_epochs}")
 
+        save_f_token_manipulations(model, 0, local_device=main_device)
+
 else:
     for n_epoch in range(args.nb_epochs):
         one_complete_epoch(
-- 
2.39.5