import math

import torch
from torch import nn
from torch.nn import functional as F
-# from torch.nn.attention.flex_attention import flex_attention
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
######################################################################
######################################################################
-def attention(q, k, v):
+def vanilla_attention(q, k, v):
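+    # Plain scaled dot-product attention softmax(QK^T / sqrt(d)) V over
+    # tensors of shape (batch, head, time, dim).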
a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
a = a.softmax(dim=3)
y = torch.einsum("nhts,nhsd->nhtd", a, v)
    return y
dim_qk,
dim_v,
nb_heads=1,
+ attention=vanilla_attention,
attention_dropout=0.0,
):
super().__init__()
def randw(*d):
return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
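+        # The attention operator is a plain callable so that masked or
+        # flex_attention-based variants can be swapped in.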
+ self.attention = attention
self.attention_dropout = attention_dropout
self.w_q = randw(nb_heads, dim_qk, dim_model)
self.w_k = randw(nb_heads, dim_qk, dim_model)
q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
- y = attention(q, k, v)
+ y = self.attention(q, k, v)
y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
return y
dim_hidden,
nb_heads,
nb_blocks,
+ attention=vanilla_attention,
dropout=0.0,
len_max=1e5,
):
######################################################################
-class MaskedAttentionAE(nn.Module):
+class FunctionalAttentionAE(AttentionAE):
def __init__(
self,
vocabulary_size,
        dim_model,
dim_keys,
dim_hidden,
nb_heads,
+ nb_work_tokens,
nb_blocks,
dropout=0.0,
len_max=1e5,
):
- super().__init__()
- self.core = AttentionAE(
- vocabulary_size * 2,
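+        # Disabled sketch of a flex_attention block mask: the first
+        # nb_work_tokens tokens attend to and from all positions, the rest
+        # attend only within their own 200-token chunk.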
+ # def functional_mask(b, h, q_idx, kv_idx):
+ # return (
+ # (q_idx < nb_work_tokens)
+ # | (kv_idx < nb_work_tokens)
+ # | ((q_idx - nb_work_tokens) // 200 == (kv_idx - nb_work_tokens) // 200)
+ # )
+
+ # block_mask = create_block_mask(
+ # functional_mask,
+ # B=None,
+ # H=None,
+ # Q_LEN=400 + nb_work_tokens,
+ # KV_LEN=400 + nb_work_tokens,
+ # )
+
+ # def functional_attention(q, k, v):
+ # return flex_attention(q, k, v, block_mask=block_mask)
+
+ AttentionAE.__init__(
+ self,
+ vocabulary_size,
dim_model,
dim_keys,
dim_hidden,
            nb_heads,
            nb_blocks,
-            dropout=0.0,
-            len_max=1e5,
+            dropout=dropout,
+            len_max=len_max,
)
+ self.nb_work_tokens = nb_work_tokens
def forward(self, x):
- x = x[:, :, 0] * 2 + x[:, :, 1]
- return self.core(x)
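+        # Prepend nb_work_tokens zero tokens as global scratch space and
+        # strip them from the output.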
+ x = torch.cat([x.new_zeros(x.size(0), self.nb_work_tokens), x], dim=1)
+ return AttentionAE.forward(self, x)[:, self.nb_work_tokens :]
######################################################################
if __name__ == "__main__":
- model = AttentionAE(
+ model = FunctionalAttentionAE(
vocabulary_size=100,
dim_model=16,
dim_keys=64,
dim_hidden=32,
nb_heads=4,
+ nb_work_tokens=10,
nb_blocks=4,
dropout=0.1,
)
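+
+    # Minimal smoke test, assuming token inputs of shape (batch, length)
+    x = torch.randint(100, (10, 50))
+    y = model(x)
+    print(y.shape)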
class Grids(problem.Problem):
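+    # Rendering style for the cell grid: the commented values draw one-pixel
+    # dark gray grid lines on a white background, the active ones disable the
+    # grid entirely.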
+    # grid_gray = 64
+    # thickness = 1
+    # background_gray = 255
+
+ grid_gray = 255
+ thickness = 0
+ background_gray = grid_gray
+
named_colors = [
- ("white", [255, 255, 255]),
+ ("white", [background_gray, background_gray, background_gray]),
# ("white", [224, 224, 224]),
("red", [255, 0, 0]),
("green", [0, 192, 0]),
y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
if grids:
- y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
- y[:, :, torch.arange(0, y.size(2), scale), :] = 64
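+            # Overwrite grid lines thickness pixels wide, one every scale
+            # pixels, in grid_gray.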
+ for t in range(self.thickness):
+ y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
+ y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
for n in range(m.size(0)):
for i in range(m.size(1)):
)
frame, white, gray, green, red = torch.tensor(
- [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
+ [
+ [self.grid_gray, self.grid_gray, self.grid_gray],
+ [255, 255, 255],
+ [200, 200, 200],
+ [0, 255, 0],
+ [255, 0, 0],
+ ],
device=quizzes.device,
)
- thickness = 1 if grids else 0
+ thickness = self.thickness
if delta:
u = (A != f_A).long()
        # Half of the samples train the prediction; we inject noise into all
        # of them, and hints into half
b_p = batch_for_prediction_imt(q_p)
- i = torch.rand(b_p.size(0)) < 0.5
b_p = add_noise_imt(b_p)
- b_p[i] = add_hints_imt(b_p[i])
+ half = torch.rand(b_p.size(0)) < 0.5
+ b_p[half] = add_hints_imt(b_p[half])
# The other half are denoising examples for the generation
b_g = batch_for_generation_imt(q_g)
# Compute the test accuracy
- nb_correct, nb_total = correct.sum(), quizzes.size(0)
+ nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
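+        # .item() converts the count to a plain Python number so that
+        # test_accuracy is a float rather than a zero-dim tensor.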
model.test_accuracy = nb_correct / nb_total
log_string(
models = []
for i in range(args.nb_models):
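+        # Disabled alternative: the work-token variant of the autoencoder.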
+ # model = attae.FunctionalAttentionAE(
+ # vocabulary_size=vocabulary_size * 2,
+ # dim_model=args.dim_model,
+ # dim_keys=args.dim_keys,
+ # dim_hidden=args.dim_hidden,
+ # nb_heads=args.nb_heads,
+ # nb_blocks=args.nb_blocks,
+ # nb_work_tokens=10,
+ # dropout=args.dropout,
+ # )
+
model = attae.AttentionAE(
vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
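+    # Log the accuracies of the models selected for further training.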
+ log_string(
+ f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+ )
+
multithread_execution(
one_complete_epoch,
[(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],