######################################################################
+class ModulatedMHAttention(nn.Module):
+ def __init__(
+ self,
+ dim_model,
+ dim_qk,
+ dim_v,
+ nb_heads=1,
+ attention=vanilla_attention,
+ attention_dropout=0.0,
+ ):
+ super().__init__()
+
+ self.dim_qk = dim_qk
+
+ def randw(*d):
+ return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+ self.attention = attention
+ self.attention_dropout = attention_dropout
+ self.w_q = randw(nb_heads, dim_qk, dim_model)
+ self.w_k = randw(nb_heads, dim_qk, dim_model)
+ self.w_v = randw(nb_heads, dim_v, dim_model)
+ self.w_o = randw(nb_heads, dim_v, dim_model)
+
+ def forward(self, x_q, x_kv=None):
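+ # the first dim_qk channels of x_q carry a per-position modulation
+ # signal, the remaining channels are the regular dim_model embedding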
+ modulation, x_q = x_q[:, :, : self.dim_qk], x_q[:, :, self.dim_qk :]
+ if x_kv is None:
+ x_kv = x_q
+
+ q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+ q = q * modulation[:, None].sigmoid()
+ k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+ v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+ y = self.attention(q, k, v)
+ y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
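+ # re-attach the modulation channels unchanged so downstream blocks can reuse them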
+ return torch.cat([modulation, y], dim=2)
+
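+ # Usage sketch (illustrative values only): the module expects dim_qk extra
+ # channels prepended to the dim_model embedding, e.g.
+ #
+ #   mha = ModulatedMHAttention(dim_model=64, dim_qk=16, dim_v=16, nb_heads=4)
+ #   x = torch.randn(2, 10, 16 + 64)  # (batch, tokens, dim_qk + dim_model)
+ #   y = mha(x)                       # same shape: (2, 10, 16 + 64)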
+
+######################################################################
+
+
+class AttentionBlock(nn.Module):
+ def __init__(
+ self,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ ):
+ super().__init__()
+ self.ln1 = nn.LayerNorm((dim_model,))
+ self.mha = MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=dropout,
+ )
+ self.ln2 = nn.LayerNorm((dim_model,))
+ self.fc1 = nn.Linear(in_features=dim_model, out_features=dim_hidden)
+ self.fc2 = nn.Linear(in_features=dim_hidden, out_features=dim_model)
+ self.drop_out = nn.Dropout(dropout)
+
+ def forward(self, x):
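+ # pre-norm self-attention with a residual connection, followed by a
+ # pre-norm two-layer MLP with a residual connection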
+ y = self.ln1(x)
+ y = self.mha(y)
+ x = x + y
+ y = self.ln2(x)
+ y = self.fc1(y)
+ y = F.relu(y)
+ y = self.fc2(y)
+ y = self.drop_out(y)
+ x = x + y
+ return x
+
+
def create_trunk(
dim_model,
dim_keys,
dropout=0.0,
residual_masker=None,
):
- trunk_blocks = []
-
- for b in range(nb_blocks):
- trunk_blocks += [
- WithResidual(
- masker=residual_masker,
- f=(
- nn.LayerNorm((dim_model,)),
- MHAttention(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention=vanilla_attention,
- attention_dropout=dropout,
- ),
- ),
- ),
- WithResidual(
- masker=residual_masker,
- f=(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ),
- ]
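+ # the trunk is a stack of nb_blocks identical AttentionBlock modules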
+ trunk_blocks = [
+ AttentionBlock(
+ dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=dropout
+ )
+ for _ in range(nb_blocks)
+ ]
return nn.Sequential(*trunk_blocks)
m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
return m[None, :, None]
- for b in range(nb_blocks):
- trunk_blocks += [
- WithMaskedResidual(
- masker,
- nn.LayerNorm((dim_model,)),
- MHAttention(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention=no_peek_attention,
- attention_dropout=dropout,
- ),
- ),
- WithMaskedResidual(
- masker,
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ]
-
- self.trunk = nn.Sequential(*trunk_blocks)
+ self.trunk = nn.Sequential(*[AttentionBlock(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=dropout) for _ in range(nb_blocks)])
self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
self.nb_chunks = nb_chunks
self.x_star = randw(nb_f_tokens, dim_model)
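+ # scale the learnable f tokens down so that they start close to zero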
+ with torch.no_grad():
+ self.x_star *= 1e-3
self.positional_encoding = VaswaniPositionalEncoding(len_max)
attention_dropout=attention_dropout,
)
- self.mha_B = MHAttention(
- dim_model=dim_model,
- dim_qk=dim_qk,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention=vanilla_attention,
- attention_dropout=attention_dropout,
- )
+ def forward_(self, x_q):
+ nb, T, dim = x_q.size()
+ nc, S = self.nb_chunks, self.x_star.size(0)
- def forward_AB(self, x_q):
- T, S = x_q.size(1), self.x_star.size(0)
- nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
-
- x = x_q
- x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
- x = self.trunk_A(x)
+ x = x_q.reshape(nb * nc, T // nc, dim)
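+ # fold the nc chunks of each sequence into the batch dimension; each chunk
+ # becomes an independent sequence of length T // nc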
f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
- f = self.mha_A(f, x)
+ # x = torch.cat([f, x], dim=1)
+ x = self.trunk_A(x)
- k = torch.arange(nb, device=x_q.device)
- u = f[k * 2, :]
- f[k * 2, :] = f[k * 2 + 1, :]
- f[k * 2 + 1, :] = u
+ x_star = self.x_star.reshape(1, S, dim)
+ f = self.mha_A(x_star, x).mean(dim=1, keepdim=True)
- f = self.mha_B(x, f)
- x = self.trunk_B(x)
- x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
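+ # exchange the pooled f vector between consecutive chunks (0 <-> 1, 2 <-> 3, ...)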
+ k = torch.arange(nb * nc, device=x_q.device)
+ k = k + 1 - 2 * (k % 2)
+ f = f[k]
+ # u = x[k * 2, :S]
+ # x[k * 2, :S] = x[k * 2 + 1, :S]
+ # x[k * 2 + 1, :S] = u
+ # x[:, S:] = x_q.reshape(nb * nc, T // nc, dim)
+
+ x = self.trunk_B(x, q_modulation=f)
+
+ x = x[:, S:].reshape(nb, T, dim)
return x
def forward(self, x_q):
T, S = x_q.size(1), self.x_star.size(0)
nb, dim, nc = x_q.size(0), x_q.size(2), self.nb_chunks
- f = self.x_star.reshape(1, S, dim).expand(nb, S, dim)
x = x_q
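+ # split each sequence into nc chunks and fold them into the batch dimension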
x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
- f = f.repeat(nc, 1, 1)
- x = torch.cat([f, x], dim=1)
x = self.trunk_A(x)
- k = torch.arange(nb, device=x_q.device)
- u = x[k * 2, :S]
- x[k * 2, :S] = x[k * 2 + 1, :S]
- x[k * 2 + 1, :S] = u
+ # f = self.x_star.reshape(1, S, dim).expand(nb * nc, S, dim)
+ # f = self.mha_A(f, x)
+
+ # if hasattr(self, "f"):
+ # if torch.is_tensor(self.f):
+ # f, self.f = self.f, f
+ # else:
+ # self.f = f
+
+ # k = torch.arange(nb, device=x_q.device)
+ # u = f[k * 2, :]
+ # f[k * 2, :] = f[k * 2 + 1, :]
+ # f[k * 2 + 1, :] = u
+
+ # x = x_q
+ # x = x.reshape(nb, nc, T // nc, dim).reshape(nb * nc, T // nc, dim)
+ # f = self.mha_B(x, f)
x = self.trunk_B(x)
- x = x[:, S:]
x = x.reshape(nb, nc, T // nc, dim).reshape(nb, T, dim)
return x
# Prediction
-def make_imt_samples_for_prediction(input):
+def make_imt_samples_for_prediction(input, u=None):
nb = input.size(0)
masks = input.new_zeros(input.size())
- u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+ if u is None:
+ u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
masks.view(nb, 4, -1)[...] = u[:, :, None]
targets = input
input = (1 - masks) * targets
######################################################################
+
+def save_f_token_manipulations(model, n_epoch, local_device):
+ quizzes = generate_quiz_set(256, None, args.c_quiz_multiplier)
+
+ u = F.one_hot(torch.full((quizzes.size(0),), 3, device=local_device), num_classes=4)
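+ # force mask index 3 for every quiz, i.e. mask the last quarter of the sequence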
+
+ imt_set = make_imt_samples_for_prediction(quizzes.to(local_device), u=u)
+
+ model.eval().to(local_device)
+
+ record = []
+
+ src = tqdm.tqdm(
+ imt_set.split(args.eval_batch_size),
+ dynamic_ncols=True,
+ desc="predict",
+ total=imt_set.size(0) // args.eval_batch_size,
+ delay=10,
+ )
+
+ N = args.eval_batch_size
+
+ for imt in src:
+ # some paranoia
+ imt = imt.clone()
+ imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+ model.trunk.f = True
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
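+ # shift masked positions into a second vocabulary range so that the
+ # mask is encoded in the token ids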
+ batch = imt[:, 0] + imt[:, 1] * vocabulary_size
+ logits = model(batch)
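+ # rebuild the batch: the first half becomes N // 2 copies of the quiz at
+ # index N // 2, the second half gets the original first half (assumes
+ # every batch holds exactly N samples), then run the model again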
+ x = batch[N // 2 : N // 2 + 1].clone()
+ batch[N // 2 :] = batch[: N // 2]
+ batch[: N // 2] = x.expand(N // 2, -1)
+ logits = model(batch)
+
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ result = dist.sample()
+ record.append(result)
+
+ result = torch.cat(record).to("cpu")
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_f_token_manipulation.png",
+ quizzes=result[:128],
+ nrow=N // 2,
+ )
+
+
if args.test == "aebn":
model = attae.AttentionAE(
vocabulary_size_in=vocabulary_size * 2,
pe, # trainable=True
)
- nb_f_tokens = 200
+ nb_f_tokens = 8
def no_f_residual(x):
m = x.new_full((1, x.size(1), 1), 1.0)
model.nb_epochs = d["nb_epochs"]
log_string(f"successfully loaded {filename} nb_epochs {model.nb_epochs}")
+ save_f_token_manipulations(model, 0, local_device=main_device)
+
else:
for n_epoch in range(args.nb_epochs):
one_complete_epoch(