parser.add_argument("--attention", type=str, default=None)
-parser.add_argument("--proportion_memex", type=float, default=0)
+parser.add_argument("--proba_memex", type=float, default=0)
parser.add_argument("--dim_model", type=int, default=None)
vocabulary_size = task.vocabulary_size()
-if args.proportion_memex > 0:
+if args.proba_memex > 0:
vocabulary_size += 1
log_string(f"vocabulary_size {vocabulary_size}")
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- def add_memex(batches, proportion_memex):
+ def add_memex(batches, proba_memex):
for input in batches:
- if torch.rand(1).item() < proportion_memex:
+ if torch.rand(1).item() < proba_memex:
+ sep = (
+ torch.full(
+ (input.size(0), 1), vocabulary_size - 1, device=input.device
+ ),
+ )
+
yield torch.cat(
[
input,
- torch.full(
- (input.size(0), 1), vocabulary_size - 1, device=input.device
- ),
+ sep,
input,
],
dim=1,
)
yield input
- train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
+ train_batches = add_memex(task.batches(split="train"), args.proba_memex)
for input in train_batches:
model.reset_inner_loss()