dim_hidden,
nb_heads,
nb_blocks,
- attention=vanilla_attention,
dropout=0.0,
len_max=1e5,
):
dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
- attention=attention,
+ attention=vanilla_attention,
attention_dropout=dropout,
),
),
######################################################################
-class FunctionalAttentionAE(AttentionAE):
+class WithMaskedResidual(nn.Module):
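+ # Residual wrapper: forward() returns mask * x + f(x). The boolean mask is
+ # produced by masker(x) on the first call and cached in self.mask, so the
+ # skip connection is suppressed wherever the mask is False; the caching
+ # assumes the sequence length stays the same across calls.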
+ def __init__(self, masker, *f):
+ super().__init__()
+ self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+ self.masker = masker
+ self.mask = None
+
+ def forward(self, x):
+ if self.mask is None:
+ self.mask = self.masker(x)
+ return self.mask * x + self.f(x)
+
+
+######################################################################
+
+
+class FunctionalAttentionAE(nn.Module):
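+ # Attention auto-encoder that prepends nb_work_tokens extra "work" positions
+ # to every sequence; residual skip connections are suppressed at those
+ # positions (see WithMaskedResidual and masker below), attention is computed
+ # with no_peek_attention, and the work positions are cropped off again
+ # before the readout.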
def __init__(
self,
vocabulary_size,
dropout=0.0,
len_max=1e5,
):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.nb_work_tokens = nb_work_tokens
+
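+ # note: the embedding table has 2 * vocabulary_size rows, twice the number
+ # of distinct tokens read out below, presumably to leave room for flagged
+ # variants of each token (an assumption; the input encoding is defined elsewhere)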
+ self.embedding = nn.Sequential(
+ nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Dropout(dropout),
+ )
+
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
+
+ trunk_blocks = []
+
def no_peek_attention(q, k, v):
a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
n = self.nb_work_tokens
y = torch.einsum("nhts,nhsd->nhtd", a, v)
return y
- AttentionAE.__init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- attention=no_peek_attention,
- dropout=0.0,
- len_max=1e5,
- )
- self.nb_work_tokens = nb_work_tokens
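+ # masker() returns a (1, T, 1) boolean mask that is False for the first
+ # nb_work_tokens positions, so WithMaskedResidual keeps the skip connection
+ # only on the ordinary data tokens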
+ def masker(x):
+ m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens
+ return m[None, :, None]
+
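+ # each trunk block: a pre-norm multi-head attention layer followed by a
+ # pre-norm two-layer MLP, both wrapped in masked residual connections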
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithMaskedResidual(
+ masker,
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=no_peek_attention,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithMaskedResidual(
+ masker,
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
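+ # initialization: small normal weights for the embeddings, identity
+ # (weight 1, bias 0) for the LayerNorms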
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
def forward(self, x):
- x = torch.cat([x.new_zeros(x.size(0), self.nb_work_tokens), x], dim=1)
- return AttentionAE.forward(self, x)[:, self.nb_work_tokens :]
+ x = self.embedding(x)
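+ # prepend nb_work_tokens all-zero embeddings at the front of the sequence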
+ x = F.pad(x, (0, 0, self.nb_work_tokens, 0))
+ x = self.positional_encoding(x)
+ x = self.trunk(x)
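+ # a negative pad length crops the nb_work_tokens work positions back off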
+ x = F.pad(x, (0, 0, -self.nb_work_tokens, 0))
+ x = self.readout(x)
+ return x
######################################################################
class Grids(problem.Problem):
- grid_gray = 64
- thickness = 1
- background_gray = 255
- dots = False
-
- # grid_gray=240
- # thickness=1
- # background_gray=240
+ # grid_gray = 64
+ # thickness = 1
+ # background_gray = 255
# dots = False
+ grid_gray = 240
+ thickness = 0
+ background_gray = 240
+ dots = False
+
# grid_gray = 192
# thickness = 0
# background_gray = 255
def vocabulary_size(self):
warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
- return self.nb_colors + 4
+ return self.nb_colors
def grid2img(self, x, scale=15, grids=True):
m = torch.logical_and(x >= 0, x < self.nb_colors).long()
grids=True,
margin=12,
delta=False,
+ delta_highlight=False,
):
quizzes = quizzes.to("cpu")
self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
)
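+ # when delta_highlight is set, pixels of the f_B image that are identical
+ # to the corresponding img_B pixels are washed out towards light gray so
+ # that the differences stand out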
+ if delta_highlight:
+ q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long()
+ img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B
+
# predicted_parts Nx4
# correct_parts Nx4
"/tmp",
t.__name__ + ".png",
w_quizzes,
+ delta=True,
# grids=False
# comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
)
# ----------------------------------
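+# selects which auto-encoder class is instantiated below: "standard" for
+# attae.AttentionAE, "functional" for attae.FunctionalAttentionAE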
+parser.add_argument("--model_type", type=str, default="standard")
+
parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
models = []
+if args.model_type == "standard":
+ model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+ model_constructor = attae.FunctionalAttentionAE
+else:
+ raise ValueError(f"Unknown model type {args.model_type}")
+
+
for i in range(args.nb_models):
- # model = attae.FunctionalAttentionAE(
- model = attae.AttentionAE(
+ model = model_constructor(
vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,