X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=d6845e8fb3a0f4c123181ea57047ffe10260acfa;hb=a1ae050705970007f965d2586c53e9bd262e46aa;hp=04e56527a51935496c69dbbb6124bd18bc160f28;hpb=e56873a0cb64555cbd47e44cdca0ce991765a5fc;p=mygptrnn.git

diff --git a/main.py b/main.py
index 04e5652..d6845e8 100755
--- a/main.py
+++ b/main.py
@@ -87,6 +87,10 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=1)
+
 parser.add_argument("--dim_model", type=int, default=None)
 
 parser.add_argument("--dim_keys", type=int, default=None)
@@ -99,7 +103,13 @@ parser.add_argument("--nb_lines", type=int, default=None)
 
 parser.add_argument("--caterpillar_height", type=int, default=None)
 
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
@@ -133,6 +143,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 parser.add_argument("--grid_size", type=int, default=6)
 
+parser.add_argument("--grid_nb_colors", type=int, default=6)
+
+parser.add_argument("--grid_nb_shapes", type=int, default=6)
+
 ##############################
 # picoclvr options
 
@@ -701,6 +715,8 @@ elif args.task == "grid":
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
         size=args.grid_size,
+        nb_shapes=args.grid_nb_shapes,
+        nb_colors=args.grid_nb_colors,
         logger=log_string,
         device=device_data,
     )
@@ -724,6 +740,9 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
+if args.memex_proba > 0:
+    vocabulary_size += 1
+
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
@@ -741,7 +760,7 @@ model = mygpt.MyGPT(
     dropout=args.dropout,
     attention_layer=args.attention,
     logger=log_string,
-    **sup_args,
+    args=args,
 )
 
 model.to(device)
@@ -835,21 +854,22 @@ if args.max_percents_of_test_in_train >= 0:
 
 ##############################
 
-for input in task.batches(split="train", desc="calibrate"):
-    input = input.to(device)
-    output = model(mygpt.BracketedSequence(input)).x
+if "calibrate" in sup_args:
+    for input in task.batches(split="train", desc="calibrate"):
+        input = input.to(device)
+        output = model(mygpt.BracketedSequence(input)).x
 
-for n, m in model.named_modules():
-    for a in dir(m):
-        x = getattr(m, a)
-        if isinstance(x, mygpt.Calibrator):
-            print(f"####### ${n} | ${a} ########################")
-            mean, std = x.moments()
-            print("mean\n", mean, "\n")
-            print("std\n", std, "\n")
-            print(f"############################################\n\n")
+    for n, m in model.named_modules():
+        for a in dir(m):
+            x = getattr(m, a)
+            if isinstance(x, mygpt.Calibrator):
+                print(f"####### ${n} | ${a} ########################")
+                mean, std = x.moments()
+                print("mean\n", mean, "\n")
+                print("std\n", std, "\n")
+                print(f"############################################\n\n")
 
-exit(0)
+    exit(0)
 
 ##############################
@@ -884,7 +904,29 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    for input in task.batches(split="train"):
+    def add_memex(batches, memex_proba):
+        for input in batches:
+            if torch.rand(1).item() < memex_proba:
+                sep = torch.full(
+                    (input.size(0), 1), vocabulary_size - 1, device=input.device
+                )
+
+                yield torch.cat(
+                    [
+                        input,
+                        sep,
+                        input,
+                    ],
+                    dim=1,
+                )
+            yield input
+
+    train_batches = add_memex(
+        task.batches(split="train"),
+        args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+    )
+
+    for input in train_batches:
         model.reset_inner_loss()
         input = input.to(device)
 
@@ -898,7 +940,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         nb_train_samples += input.size(0)
         nb_samples_seen += input.size(0)
 
-        total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+        total_loss = loss + (
+            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+        )
 
         it += 1
         lr = get_lr(n_epoch, it)
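The two new gate_dropout flags --gate_dropout_sync and --gate_dropout_replace declare type=str2bool, a converter that this diff references but does not define. Below is a minimal sketch of such a converter under the usual argparse convention; the name str2bool comes from the diff, but the body is an assumption, not the repository's actual definition.

import argparse

def str2bool(x):
    # Accept the common textual spellings so that
    # --gate_dropout_sync true / false / yes / no / 1 / 0 all parse.
    x = x.lower()
    if x in {"1", "true", "yes"}:
        return True
    elif x in {"0", "false", "no"}:
        return False
    else:
        raise argparse.ArgumentTypeError(f"cannot interpret {x!r} as a boolean")

An explicit converter is needed because argparse's type=bool would treat any non-empty string, including "false", as True.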
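The memex change spans three hunks: when args.memex_proba > 0 the vocabulary grows by one token, whose id vocabulary_size - 1 serves as a separator, and during the first memex_nb_epochs epochs add_memex interleaves, with probability memex_proba, batches of the form [input | sep | input] among the unmodified ones. Below is a self-contained sketch of that generator with toy sizes, assuming only PyTorch; the function body mirrors the diff, while the surrounding driver code is illustrative.

import torch

vocabulary_size = 11  # toy value; the last id, 10, plays the separator role

def add_memex(batches, memex_proba):
    # With probability memex_proba, first yield the batch repeated twice
    # around a column of separator tokens; always yield the plain batch.
    for input in batches:
        if torch.rand(1).item() < memex_proba:
            sep = torch.full(
                (input.size(0), 1), vocabulary_size - 1, device=input.device
            )
            yield torch.cat([input, sep, input], dim=1)
        yield input

batches = (torch.randint(vocabulary_size - 1, (2, 5)) for _ in range(4))
for batch in add_memex(batches, memex_proba=0.5):
    print(batch.size())  # (2, 11) for memex batches, (2, 5) otherwise

Since the training loop passes args.memex_proba only while n_epoch < args.memex_nb_epochs, the augmentation is active for the first memex_nb_epochs epochs and the probability drops to 0.0 afterwards.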