X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=3aa696b38784c8270c5a840c3e4d5be61dacad2f;hb=c45d89eb5383eedf60466678eae623582bd5781c;hp=79841f3e518b1854bf3dd21034e3a49096fff49c;hpb=8fdce4736a05a37d0f8706148dd743bce123fe1b;p=mygptrnn.git diff --git a/main.py b/main.py index 79841f3..3aa696b 100755 --- a/main.py +++ b/main.py @@ -99,7 +99,11 @@ parser.add_argument("--nb_lines", type=int, default=None) parser.add_argument("--caterpillar_height", type=int, default=None) -parser.add_argument("--rho", type=float, default=0.0) +parser.add_argument("--gate_dropout_proba", type=float, default=0.0) + +parser.add_argument("--gate_dropout_sync", type=bool, default=False) + +parser.add_argument("--rho_inner_loss", type=float, default=0.0) parser.add_argument("--nb_blocks", type=int, default=None) @@ -747,7 +751,7 @@ model = mygpt.MyGPT( dropout=args.dropout, attention_layer=args.attention, logger=log_string, - **sup_args, + args=args, ) model.to(device) @@ -905,7 +909,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples += input.size(0) nb_samples_seen += input.size(0) - total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + total_loss = loss + ( + args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 + ) it += 1 lr = get_lr(n_epoch, it)