dim_qk=dim_keys,
dim_v=dim_model // nb_heads,
nb_heads=nb_heads,
+ attention=attention,
attention_dropout=dropout,
),
),
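# Note: the attention callable is now a constructor argument, which is
# what allows FunctionalAttentionAE below to swap in no_peek_attention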
dim_keys,
dim_hidden,
nb_heads,
- nb_work_tokens,
nb_blocks,
+ nb_work_tokens=100,
dropout=0.0,
len_max=1e5,
):
- # def functional_mask(b, h, q_idx, kv_idx):
- # return (
- # (q_idx < nb_work_tokens)
- # | (kv_idx < nb_work_tokens)
- # | ((q_idx - nb_work_tokens) // 200 == (kv_idx - nb_work_tokens) // 200)
- # )
-
- # block_mask = create_block_mask(
- # functional_mask,
- # B=None,
- # H=None,
- # Q_LEN=400 + nb_work_tokens,
- # KV_LEN=400 + nb_work_tokens,
- # )
-
- # def functional_attention(q, k, v):
- # return flex_attention(q, k, v, block_mask=block_mask)
+ def no_peek_attention(q, k, v):
+ # Plain scaled dot-product attention, with the two diagonal blocks
+ # masked so that each half of the sequence attends only to the work
+ # tokens and to the other half, never to itself
+ a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+ n = self.nb_work_tokens
+ s = (q.size(2) - n) // 2
+ # Mask intra-half attention before the softmax
+ a[:, :, n + 0 * s : n + 1 * s, n + 0 * s : n + 1 * s] = float("-inf")
+ a[:, :, n + 1 * s : n + 2 * s, n + 1 * s : n + 2 * s] = float("-inf")
+ a = a.softmax(dim=3)
+ y = torch.einsum("nhts,nhsd->nhtd", a, v)
+ return y
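+ # A minimal sketch of the mask this implements (hypothetical sizes,
+ # not part of the model): position t is a work token if t < n,
+ # otherwise it belongs to half (t - n) // s, and entry (t, u) is
+ # masked iff both positions fall in the same half:
+ #
+ #   n, s = 2, 3
+ #   idx = torch.arange(n + 2 * s)
+ #   half = torch.where(idx < n, -1, (idx - n) // s)  # -1 = work token
+ #   masked = (half[:, None] >= 0) & (half[:, None] == half[None, :])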
AttentionAE.__init__(
self,
dim_hidden,
nb_heads,
nb_blocks,
+ attention=no_peek_attention,
- dropout=0.0,
- len_max=1e5,
+ dropout=dropout,
+ len_max=len_max,
)
def log_string(s):
+ """print the given string prefixed with a time stamps, and log it into log_file is not None"""
+
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
if log_file is not None:
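# Assumed continuation, elided from this hunk: write the stamped
# string to the log file and echo it to stdout
log_file.write(t + s + "\n")
log_file.flush()
print(t + s)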
######################################################################
-
-# IMT for input / masks / target
-
-# Generate a batch for prediction
+# Prediction
-def batch_for_prediction_imt(input):
+def samples_for_prediction_imt(input):
nb = input.size(0)
masks = input.new_zeros(input.size())
u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
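# u one-hot selects, per sample, one of four parts uniformly at random
# (an assumption: the quiz is made of four grids, and the selected one
# becomes the masked part to predict)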
######################################################################
-def batch_for_generation_imt(input):
+def samples_for_generation_imt(input):
nb = input.size(0)
probs_iterations = 0.1 ** torch.linspace(
0, 1, args.diffusion_nb_iterations, device=input.device
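# probs_iterations decays geometrically from 1.0 down to 0.1, e.g.
# 0.1 ** torch.linspace(0, 1, 5) = [1.0000, 0.5623, 0.3162, 0.1778, 0.1000]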
# Half of the samples train the prediction; we inject noise into
# all of them, and hints into half
- b_p = batch_for_prediction_imt(q_p)
+ b_p = samples_for_prediction_imt(q_p)
b_p = add_noise_imt(b_p)
half = torch.rand(b_p.size(0)) < 0.5
b_p[half] = add_hints_imt(b_p[half])
# The other half are denoising examples for generation
- b_g = batch_for_generation_imt(q_g)
+ b_g = samples_for_generation_imt(q_g)
imt_set = torch.cat([b_p, b_g])
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
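# Shuffle so that every batch mixes prediction and generation samples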
quizzes = quiz_machine.quiz_set(
args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
)
- imt_set = batch_for_prediction_imt(quizzes.to(local_device))
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
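# imt_set stacks input / masks / target along dim 1 (the "IMT" layout),
# hence index 1 recovers the masks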
models = []
for i in range(args.nb_models):
- # model = attae.FunctionalAttentionAE(
- # vocabulary_size=vocabulary_size * 2,
- # dim_model=args.dim_model,
- # dim_keys=args.dim_keys,
- # dim_hidden=args.dim_hidden,
- # nb_heads=args.nb_heads,
- # nb_blocks=args.nb_blocks,
- # nb_work_tokens=10,
- # dropout=args.dropout,
- # )
-
- model = attae.AttentionAE(
+ model = attae.FunctionalAttentionAE(
+ # model = attae.AttentionAE(
vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,