X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=4d5077ae1b46b7398aed253c14129d5f6b451879;hb=a3c32b845b6903fd290f2b09d5c53203ff112b79;hp=79841f3e518b1854bf3dd21034e3a49096fff49c;hpb=64dc96ddfa84511ba07d1929481e93e864735409;p=mygptrnn.git

diff --git a/main.py b/main.py
index 79841f3..4d5077a 100755
--- a/main.py
+++ b/main.py
@@ -87,6 +87,8 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
+parser.add_argument("--proportion_memex", type=float, default=0)
+
 parser.add_argument("--dim_model", type=int, default=None)
 
 parser.add_argument("--dim_keys", type=int, default=None)
@@ -99,7 +101,13 @@ parser.add_argument("--nb_lines", type=int, default=None)
 
 parser.add_argument("--caterpillar_height", type=int, default=None)
 
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
@@ -730,6 +738,9 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
+if args.proportion_memex > 0:
+    vocabulary_size += 1
+
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
@@ -747,7 +758,7 @@ model = mygpt.MyGPT(
     dropout=args.dropout,
     attention_layer=args.attention,
     logger=log_string,
-    **sup_args,
+    args=args,
 )
 
 model.to(device)
@@ -891,7 +902,24 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    for input in task.batches(split="train"):
+    def add_memex(batches, proportion_memex):
+        for input in batches:
+            if torch.rand(1).item() < proportion_memex:
+                yield torch.cat(
+                    [
+                        input,
+                        torch.full(
+                            (input.size(0), 1), vocabulary_size - 1, device=input.device
+                        ),
+                        input,
+                    ],
+                    dim=1,
+                )
+            yield input
+
+    train_batches = add_memex(task.batches(split="train"), args.proportion_memex)
+
+    for input in train_batches:
         model.reset_inner_loss()
         input = input.to(device)
 
@@ -905,7 +933,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         nb_train_samples += input.size(0)
         nb_samples_seen += input.size(0)
 
-        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+        total_loss = loss + (
+            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+        )
 
         it += 1
         lr = get_lr(n_epoch, it)
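
A minimal illustrative sketch, not part of the commit above: it shows what the new add_memex generator does to a training batch. The toy vocabulary_size and the random batch are assumptions for demonstration only; in main.py the separator is the extra token id added to the vocabulary when --proportion_memex > 0.

import torch

vocabulary_size = 11  # toy value; the last token id serves as the separator

def add_memex(batches, proportion_memex):
    # With probability proportion_memex, yield an extra "memex" batch of the
    # form [sequence, separator, sequence], so the model is also trained to
    # recall a prefix verbatim; the original batch is always yielded too.
    for input in batches:
        if torch.rand(1).item() < proportion_memex:
            separator = torch.full(
                (input.size(0), 1), vocabulary_size - 1, device=input.device
            )
            yield torch.cat([input, separator, input], dim=1)
        yield input

batch = torch.randint(vocabulary_size - 1, (2, 4))  # 2 toy sequences of 4 tokens
for augmented in add_memex([batch], proportion_memex=1.0):
    print(augmented.shape)  # torch.Size([2, 9]) then torch.Size([2, 4])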