./main.py --task=expr --nb_blocks=48 --dim_model=1024 --nb_train_samples=2500000 --result_dir=results_expr_48b_d1024_2.5M
======================================================================
+25.07.2023
+
+./main.py --task=sandbox --nb_train_samples=10000 --nb_test_samples=1000 --nb_blocks=4 --nb_heads=1 --nb_epochs=20
# problem,
# problems.ProblemAddition(zero_padded=False, inverted_result=False),
# problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
- problems.ProblemTwoTargets(len_total=12, len_targets=4),
+ problems.ProblemTwoTargets(len_total=16, len_targets=4),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
a1 = s.gather(dim=1, index=k1 + 1 + torch.arange(self.len_targets - 2)[None, :])
a2 = s.gather(dim=1, index=k2 + 1 + torch.arange(self.len_targets - 2)[None, :])
sequences = torch.cat(
- (s, torch.full((nb, 1), 12), a1, torch.full((nb, 1), 12), a2), 1
+ (
+ s,
+ torch.full((nb, 1), 12),
+ a1,
+ torch.full((nb, 1), 12),
+ a2,
+ torch.full((nb, 1), 12),
+ ),
+ 1,
)
ar_mask = (sequences == 12).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
def seq2str(self, seq):
- return "".join("0123456789+-|"[x.item()] for x in seq)
+ return "".join("0123456789-+|"[x.item()] for x in seq)
####################
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
- if save_attention_image is not None:
+ if save_attention_image is None:
+ logger("no save_attention_image (is pycairo installed?)")
+ else:
for k in range(10):
ns = torch.randint(self.test_input.size(0), (1,)).item()
input = self.test_input[ns : ns + 1].clone()
f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
)
- if save_attention_image is not None:
+ if save_attention_image is None:
+ logger("no save_attention_image (is pycairo installed?)")
+ else:
ns = torch.randint(self.test_input.size(0), (1,)).item()
input = self.test_input[ns : ns + 1].clone()
last = (input != self.t_nul).max(0).values.nonzero().max() + 3