######################################################################
+class AdHocPositionalEncoding(nn.Module):
+    """Concatenate a fixed per-position feature table to the input, then project.
+
+    `value` is used as a 2-D tensor with one row of features per position.
+    In `forward` it is shared across the batch, concatenated to `x` along the
+    last dimension, and mapped to `dim_model` features by a linear layer.
+
+    Args:
+        dim_model: size of the last dimension of both the input and the output.
+        value: 2-D tensor of per-position features. A copy is stored, so the
+            caller's tensor is left untouched. Non-floating tensors (e.g.
+            integer bit patterns) are converted to the default float dtype.
+        trainable: if true, the stored copy is an `nn.Parameter`; otherwise
+            it is registered as a buffer.
+    """
+
+    def __init__(self, dim_model, value, trainable=False):
+        super().__init__()
+        value = value.clone()
+        # nn.Parameter rejects integer tensors (only floating point and complex
+        # dtypes can require gradients). Cast in both branches so the stored
+        # dtype does not depend on `trainable`.
+        if not value.is_floating_point():
+            value = value.to(torch.get_default_dtype())
+        if trainable:
+            self.value = nn.Parameter(value)
+        else:
+            self.register_buffer("value", value)
+        self.fc = nn.Linear(
+            in_features=value.size(-1) + dim_model, out_features=dim_model
+        )
+
+    def forward(self, x):
+        """Return `fc(cat([value, x], dim=2))` with `value` shared over the batch.
+
+        `x` must be 3-D with `x.size(1) == value.size(0)` and
+        `x.size(2) == dim_model`; otherwise the concatenation or the linear
+        layer raises.
+        """
+        # expand() gives a broadcast view, avoiding a per-batch copy of the
+        # table before cat() builds the concatenated tensor.
+        value = self.value[None, :, :].expand(x.size(0), -1, -1)
+        return self.fc(torch.cat([value, x], dim=2))
+
+
+######################################################################
+
+
class WithResidual(nn.Module):
def __init__(self, *f):
super().__init__()
problem.save_quizzes_as_image(
args.result_dir,
- f"culture_prediction_{n_epoch}_{model.id}.png",
+ f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
quizzes=result[:128],
predicted_parts=predicted_parts[:128],
correct_parts=correct_parts[:128],
problem.save_quizzes_as_image(
args.result_dir,
- f"culture_generation_{n_epoch}_{model.id}.png",
+ f"culture_generation_{n_epoch:04d}_{model.id:02d}.png",
quizzes=result[:128],
)
len_max=1e4,
)
- model.positional_encoding = attae.BlockRandomPositionalEncoding(
- args.dim_model, 100, 4
- )
+ # model.positional_encoding = attae.BlockRandomPositionalEncoding(
+ # args.dim_model, 100, 4
+ # )
+
+ i = torch.arange(400)[:, None]
+ k = [2**k for k in range(4)] + [10 * 2**k for k in range(4)] + [100, 200]
+ k = torch.tensor(k)[None, :]
+ pe = (i // k) % 2
+
+ model.positional_encoding = attae.AdHocPositionalEncoding(args.dim_model, pe)
model.trunk = attae.Reasoning(
- nb_f_tokens=25,
+ nb_f_tokens=8,
nb_chunks=2,
dim_model=args.dim_model,
dim_qk=args.dim_keys,