X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=91c885b5177619ab33d6b62ec7c87825ab1d2543;hb=7b2d37d9c7ffb10f9bd81ef6356ef7083614a380;hp=c51035c118f6480f2f33560b8aeea020c4cfbf28;hpb=ffe183868ac8563fd82fc8312fda90f6f8a95833;p=mygptrnn.git diff --git a/main.py b/main.py index c51035c..91c885b 100755 --- a/main.py +++ b/main.py @@ -16,14 +16,6 @@ import mygpt, tasks, problems ###################################################################### -if torch.cuda.is_available(): - device = torch.device("cuda") - torch.backends.cuda.matmul.allow_tf32 = True -else: - device = torch.device("cpu") - -###################################################################### - def str2bool(x): x = x.lower() @@ -55,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0) parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--force_cpu", type=str2bool, default=False) + ######################################## parser.add_argument("--nb_epochs", type=int, default=50) @@ -93,6 +87,8 @@ parser.add_argument("--model", type=str, default=None) parser.add_argument("--attention", type=str, default=None) +parser.add_argument("--proba_memex", type=float, default=0) + parser.add_argument("--dim_model", type=int, default=None) parser.add_argument("--dim_keys", type=int, default=None) @@ -105,7 +101,13 @@ parser.add_argument("--nb_lines", type=int, default=None) parser.add_argument("--caterpillar_height", type=int, default=None) -parser.add_argument("--rho", type=float, default=0.0) +parser.add_argument("--gate_dropout_proba", type=float, default=0.0) + +parser.add_argument("--gate_dropout_sync", type=str2bool, default=False) + +parser.add_argument("--gate_dropout_replace", type=str2bool, default=False) + +parser.add_argument("--rho_inner_loss", type=float, default=0.0) parser.add_argument("--nb_blocks", type=int, default=None) @@ -139,6 +141,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False) parser.add_argument("--grid_size", type=int, default=6) +parser.add_argument("--grid_nb_colors", type=int, default=6) + +parser.add_argument("--grid_nb_shapes", type=int, default=6) + ############################## # picoclvr options @@ -208,15 +214,25 @@ parser.add_argument("--mixing_deterministic_start", action="store_true", default ###################################################################### -args = parser.parse_args() +# args = parser.parse_args() -assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} +args, sup_args = parser.parse_known_args() + +sup_args = dict([x.removeprefix("--").split("=") for x in sup_args]) if args.result_dir is None: args.result_dir = f"results_{args.task}_{args.model}" ###################################################################### +if not args.force_cpu and torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cuda.matmul.allow_tf32 = True +else: + device = torch.device("cpu") + +###################################################################### + default_task_args = { "addition": { "model": "352M", @@ -430,6 +446,8 @@ except FileExistsError: print(f"result directory {args.result_dir} already exists") exit(1) +loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a") + log_file = open(os.path.join(args.result_dir, args.log_filename), "a") if args.seed >= 0: @@ -466,6 +484,9 @@ log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): log_string(f"args.{n} {getattr(args, n)}") +for k, v in sup_args.items(): + log_string(f'sup_args["{k}"] "{v}"') + ###################################################################### @@ -503,6 +524,9 @@ def get_lr(n_epoch, it): ###################################################################### +assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} + + def picoclvr_pruner_horizontal_green(p): return not ("green" in p and ("left" in p or "right" in p)) @@ -689,6 +713,8 @@ elif args.task == "grid": nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, size=args.grid_size, + nb_shapes=args.grid_nb_shapes, + nb_colors=args.grid_nb_colors, logger=log_string, device=device_data, ) @@ -712,6 +738,9 @@ log_string(f"device {device}") vocabulary_size = task.vocabulary_size() +if args.proba_memex > 0: + vocabulary_size += 1 + log_string(f"vocabulary_size {vocabulary_size}") ############################## @@ -728,6 +757,8 @@ model = mygpt.MyGPT( causal=True, dropout=args.dropout, attention_layer=args.attention, + logger=log_string, + args=args, ) model.to(device) @@ -821,6 +852,25 @@ if args.max_percents_of_test_in_train >= 0: ############################## +if "calibrate" in sup_args: + for input in task.batches(split="train", desc="calibrate"): + input = input.to(device) + output = model(mygpt.BracketedSequence(input)).x + + for n, m in model.named_modules(): + for a in dir(m): + x = getattr(m, a) + if isinstance(x, mygpt.Calibrator): + print(f"####### ${n} | ${a} ########################") + mean, std = x.moments() + print("mean\n", mean, "\n") + print("std\n", std, "\n") + print(f"############################################\n\n") + + exit(0) + +############################## + nb_samples_seen = 0 if nb_epochs_finished >= nb_epochs: @@ -832,10 +882,12 @@ if nb_epochs_finished >= nb_epochs: deterministic_synthesis=args.deterministic_synthesis, ) -time_pred_result = None +time_pred_result = datetime.datetime.now() it = 0 +n_batch = 0 + for n_epoch in range(nb_epochs_finished, nb_epochs): if args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) @@ -850,7 +902,28 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0 - for input in task.batches(split="train"): + def add_memex(batches, proba_memex): + for input in batches: + if torch.rand(1).item() < proba_memex: + sep = ( + torch.full( + (input.size(0), 1), vocabulary_size - 1, device=input.device + ), + ) + + yield torch.cat( + [ + input, + sep, + input, + ], + dim=1, + ) + yield input + + train_batches = add_memex(task.batches(split="train"), args.proba_memex) + + for input in train_batches: model.reset_inner_loss() input = input.to(device) @@ -864,7 +937,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): nb_train_samples += input.size(0) nb_samples_seen += input.size(0) - total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0) + total_loss = loss + ( + args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0 + ) it += 1 lr = get_lr(n_epoch, it) @@ -877,6 +952,12 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): total_loss.backward() optimizer.step() + grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt() + + loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n") + + n_batch += 1 + with torch.autograd.no_grad(): model.eval() @@ -910,10 +991,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): ) time_current_result = datetime.datetime.now() - if time_pred_result is not None: - log_string( - f"next_result {time_current_result + (time_current_result - time_pred_result)}" - ) + log_string( + f"next_result {time_current_result + (time_current_result - time_pred_result)}" + ) time_pred_result = time_current_result checkpoint = {