X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=ec50722f7f6f8e295563ff6960d45ad9490e4f0a;hb=6d23462ce76c9020dcd7c4bc8a0e7a0fae9b7971;hp=79841f3e518b1854bf3dd21034e3a49096fff49c;hpb=64dc96ddfa84511ba07d1929481e93e864735409;p=mygptrnn.git diff --git a/main.py b/main.py index 79841f3..ec50722 100755 --- a/main.py +++ b/main.py @@ -99,7 +99,13 @@ parser.add_argument("--nb_lines", type=int, default=None) parser.add_argument("--caterpillar_height", type=int, default=None) -parser.add_argument("--rho", type=float, default=0.0) +parser.add_argument("--gate_dropout_proba", type=float, default=0.0) + +parser.add_argument("--gate_dropout_sync", type=str2bool, default=True) + +parser.add_argument("--gate_dropout_replace", type=str2bool, default=True) + +parser.add_argument("--rho_inner_loss", type=float, default=0.0) parser.add_argument("--nb_blocks", type=int, default=None) @@ -747,7 +753,7 @@ model = mygpt.MyGPT( dropout=args.dropout, attention_layer=args.attention, logger=log_string, - **sup_args, + args=args, ) model.to(device) @@ -905,7 +911,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples += input.size(0) nb_samples_seen += input.size(0) - total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + total_loss = loss + ( + args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 + ) it += 1 lr = get_lr(n_epoch, it)