+ def add_memex(batches, memex_proba):
+ for input in batches:
+ if torch.rand(1).item() < memex_proba:
+ sep = torch.full(
+ (input.size(0), 1), vocabulary_size - 1, device=input.device
+ )
+
+ yield torch.cat(
+ [
+ input,
+ sep,
+ input,
+ ],
+ dim=1,
+ )
+ yield input
+
+ train_batches = add_memex(
+ task.batches(split="train"),
+ args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+ )
+
+ for input in train_batches: