parser.add_argument("--attention", type=str, default=None)
+parser.add_argument("--proportion_memex", type=float, default=0)
+
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--dim_keys", type=int, default=None)
parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
-parser.add_argument("--gate_dropout_sync", type=str2bool, default=True)
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
-parser.add_argument("--gate_dropout_replace", type=str2bool, default=True)
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
parser.add_argument("--rho_inner_loss", type=float, default=0.0)
vocabulary_size = task.vocabulary_size()
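+# Reserve one extra token id (vocabulary_size - 1) to act as the separator
+# between the two copies of a sequence in a memex batch.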
+if args.proportion_memex > 0:
+    vocabulary_size += 1
+
log_string(f"vocabulary_size {vocabulary_size}")
##############################
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
-    for input in task.batches(split="train"):
+    def add_memex(batches, proportion_memex):
+        for input in batches:
+            if torch.rand(1).item() < proportion_memex:
+                # Replace the batch with its memex variant: the batch, a
+                # separator token, then the batch again, so the model is
+                # trained to reproduce a sequence already present in its
+                # context.
+                yield torch.cat(
+                    [
+                        input,
+                        torch.full(
+                            (input.size(0), 1),
+                            vocabulary_size - 1,
+                            dtype=input.dtype,
+                            device=input.device,
+                        ),
+                        input,
+                    ],
+                    dim=1,
+                )
+            else:
+                yield input
+
+    train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
+
+    for input in train_batches:
model.reset_inner_loss()
input = input.to(device)