X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=ec50722f7f6f8e295563ff6960d45ad9490e4f0a;hb=6d23462ce76c9020dcd7c4bc8a0e7a0fae9b7971;hp=c22ae57ecaaf0dcae71b17d0de84b1ab0941a5cf;hpb=73acbc986f9c386c001117581c4fc72d2f36803a;p=mygptrnn.git diff --git a/main.py b/main.py index c22ae57..ec50722 100755 --- a/main.py +++ b/main.py @@ -99,7 +99,13 @@ parser.add_argument("--nb_lines", type=int, default=None) parser.add_argument("--caterpillar_height", type=int, default=None) -parser.add_argument("--rho", type=float, default=0.0) +parser.add_argument("--gate_dropout_proba", type=float, default=0.0) + +parser.add_argument("--gate_dropout_sync", type=str2bool, default=True) + +parser.add_argument("--gate_dropout_replace", type=str2bool, default=True) + +parser.add_argument("--rho_inner_loss", type=float, default=0.0) parser.add_argument("--nb_blocks", type=int, default=None) @@ -133,6 +139,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_nb_colors", type=int, default=6) + +parser.add_argument("--grid_nb_shapes", type=int, default=6) + ############################## # picoclvr options @@ -465,15 +475,15 @@ with os.popen("sha256sum *.py") as f: log_string(f"sha256sum {l.strip()}") now = time.strftime("%Y%m%d-%H%M%S", time.localtime()) -os.system(f"tar --ignore-failed-read zcvf {args.result_dir}/src-{now}.tgz *.py *.sh") +os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh") log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): log_string(f"args.{n} {getattr(args, n)}") -for n in vars(sup_args): - log_string(f"sup_args.{n} {getattr(sup_args, n)}") +for k, v in sup_args.items(): + log_string(f'sup_args["{k}"] "{v}"') ###################################################################### @@ -701,6 +711,8 @@ elif args.task == "grid": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, size=args.grid_size, + nb_shapes=args.grid_nb_shapes, + nb_colors=args.grid_nb_colors, logger=log_string, device=device_data, ) @@ -741,7 +753,7 @@ model = mygpt.MyGPT( dropout=args.dropout, attention_layer=args.attention, logger=log_string, - **sup_args, + args=args, ) model.to(device) @@ -835,6 +847,25 @@ if args.max_percents_of_test_in_train >= 0: ############################## +if "calibrate" in sup_args: + for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + + for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + + exit(0) + +############################## + nb_samples_seen = 0 if nb_epochs_finished >= nb_epochs: @@ -880,7 +911,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples += input.size(0) nb_samples_seen += input.size(0) - total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + total_loss = loss + ( + args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 + ) it += 1 lr = get_lr(n_epoch, it)