X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=625480798423c770c80dc3c353644d2566e4cc99;hb=8012a611e9920816fe6ba382b69305242136bc2a;hp=3aa696b38784c8270c5a840c3e4d5be61dacad2f;hpb=c45d89eb5383eedf60466678eae623582bd5781c;p=mygptrnn.git

diff --git a/main.py b/main.py
index 3aa696b..6254807 100755
--- a/main.py
+++ b/main.py
@@ -11,6 +11,8 @@ import torch, torchvision
 from torch import nn
 from torch.nn import functional as F
 
+# torch.autograd.set_detect_anomaly(True) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
 import ffutils
 import mygpt, tasks, problems
 
@@ -51,9 +53,11 @@ parser.add_argument("--force_cpu", type=str2bool, default=False)
 
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=25)
+
+parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=None)
 
@@ -87,6 +91,10 @@ parser.add_argument("--model", type=str, default=None)
 
 parser.add_argument("--attention", type=str, default=None)
 
+parser.add_argument("--memex_proba", type=float, default=0)
+
+parser.add_argument("--memex_nb_epochs", type=float, default=None)
+
 parser.add_argument("--dim_model", type=int, default=None)
 
 parser.add_argument("--dim_keys", type=int, default=None)
@@ -101,7 +109,9 @@ parser.add_argument("--caterpillar_height", type=int, default=None)
 
 parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
 
-parser.add_argument("--gate_dropout_sync", type=bool, default=False)
+parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)
+
+parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)
 
 parser.add_argument("--rho_inner_loss", type=float, default=0.0)
 
@@ -232,97 +242,97 @@ else:
 default_task_args = {
     "addition": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "byheart": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
     },
     "expr": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 2500000,
         "nb_test_samples": 10000,
     },
     "grid": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "qmlp": {
         "model": "37M",
-        "batch_size": 10,
+        "physical_batch_size": 10,
         "nb_train_samples": 100000,
         "nb_test_samples": 1000,
     },
     "guessop": {
         "model": "352M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 1000000,
         "nb_test_samples": 10000,
     },
     "learnop": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 50000,
         "nb_test_samples": 10000,
    },
     "maze": {
         "model": "37M",
-        "batch_size": 5,
+        "physical_batch_size": 5,
         "nb_train_samples": 100000,
         "nb_test_samples": 10000,
     },
     "picoclvr": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "rpl": {
         "model": "352M",
-        "batch_size": 5,
+        "physical_batch_size": 5,
         "nb_train_samples": 2500000,
         "nb_test_samples": 10000,
     },
     "snake": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
     "stack": {
         "model": "37M",
-        "batch_size": 25,
+        "physical_batch_size": 25,
         "nb_train_samples": 100000,
"nb_test_samples": 1000, }, "twotargets": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 50000, "nb_test_samples": 10000, }, "memory": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 25000, "nb_test_samples": 10000, }, "mixing": { "model": "37M", - "batch_size": 25, + "physical_batch_size": 25, "nb_train_samples": 250000, "nb_test_samples": 10000, }, "mnist": { "model": "37M", - "batch_size": 10, + "physical_batch_size": 5, "nb_train_samples": 60000, "nb_test_samples": 10000, }, @@ -520,6 +530,90 @@ def get_lr(n_epoch, it): ###################################################################### +def add_memex_v2(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(1 + 2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + +def add_memex_v3(batches, memex_proba, marker_token): + for input in batches: + if torch.rand(1).item() < memex_proba: + t = ( + torch.arange(2 * input.size(1), device=input.device)[None, :] + .expand(input.size(0), -1) + .clone() + ) + + u = torch.rand(t.size(), device=t.device) + u[:, : input.size(1)] = 1.0 + memex_v3_proba_fragment = 1 / 20 + u = (u < memex_v3_proba_fragment).long() + v = u * torch.randint(input.size(1), u.size()) + u[:, input.size(1) + 1 :] = v[:, input.size(1) + 1 :] - u[ + :, : input.size(1) - 1 + ] * input.size(1) + u = u.cumsum().clamp(min=0) + + u0 = torch.randint(input.size(1), (input.size(0), 1), device=input.device) + caterpillar_length = args.nb_lines // args.caterpillar_height + u1 = ( + u0 + + torch.randint( + caterpillar_length, (input.size(0), 1), device=input.device + ) + + 1 + ) + + m0 = (t < u0).long() + m1 = (t >= u1).long() * (t < u1 + input.size(1)).long() + + t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1 + m = (t < 0).long() + n = torch.arange(input.size(0), device=input.device)[:, None].expand( + -1, t.size(1) + ) + + new_input = input[n, t.clamp(min=0)] + new_input = (1 - m) * new_input + m * (marker_token) + + yield new_input + + yield input + + +###################################################################### + assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"} @@ -548,7 +642,7 @@ if args.task == "byheart": problem=problems.ProblemByHeart(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -559,7 +653,7 @@ elif args.task == "learnop": problem=problems.ProblemLearnOperator(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, - batch_size=args.batch_size, + batch_size=args.physical_batch_size, logger=log_string, device=device_data, ) @@ -570,7 +664,7 @@ elif args.task == "guessop": 
         problem=problems.ProblemGuessOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -581,7 +675,7 @@ elif args.task == "twotargets":
         problem=problems.ProblemTwoTargets(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -591,7 +685,7 @@ elif args.task == "memory":
         problem=problems.ProblemMemory(len_total=args.memory_len_total),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -603,7 +697,7 @@ elif args.task == "mixing":
         ),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -613,7 +707,7 @@ elif args.task == "addition":
         problem=problems.ProblemAddition(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device_data,
     )
@@ -622,7 +716,7 @@ elif args.task == "picoclvr":
     task = tasks.PicoCLVR(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.picoclvr_height,
         width=args.picoclvr_width,
         nb_colors=args.picoclvr_nb_colors,
@@ -636,7 +730,7 @@ elif args.task == "mnist":
     task = tasks.MNIST(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device_data,
     )
 
@@ -644,7 +738,7 @@ elif args.task == "maze":
     task = tasks.Maze(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.maze_height,
         width=args.maze_width,
         nb_walls=args.maze_nb_walls,
@@ -655,7 +749,7 @@ elif args.task == "snake":
     task = tasks.Snake(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.snake_height,
         width=args.snake_width,
         nb_colors=args.snake_nb_colors,
@@ -668,7 +762,7 @@ elif args.task == "stack":
     task = tasks.Stack(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
@@ -685,7 +779,7 @@ elif args.task == "expr":
         sequence_length=args.expr_sequence_length,
         operand_max=args.expr_operand_max,
         result_max=args.expr_result_max,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device_data,
     )
 
@@ -693,7 +787,7 @@ elif args.task == "rpl":
     task = tasks.RPL(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         nb_starting_values=args.rpl_nb_starting_values,
         max_input=args.rpl_max_input,
         prog_len=args.rpl_prog_len,
@@ -707,7 +801,7 @@ elif args.task == "grid":
     task = tasks.Grid(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         size=args.grid_size,
         nb_shapes=args.grid_nb_shapes,
         nb_colors=args.grid_nb_colors,
@@ -719,7 +813,7 @@ elif args.task == "qmlp":
     task = tasks.QMLP(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         result_dir=args.result_dir,
         logger=log_string,
         device=device_data,
@@ -734,6 +828,9 @@ log_string(f"device {device}")
 
 vocabulary_size = task.vocabulary_size()
 
+if args.memex_proba > 0:
+    vocabulary_size += 1
+
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
@@ -895,38 +992,63 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
 
     nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
 
-    for input in task.batches(split="train"):
-        model.reset_inner_loss()
-        input = input.to(device)
+    memex_proba = (
+        args.memex_proba
+        if args.memex_nb_epochs is None or n_epoch < args.memex_nb_epochs
+        else 0.0
+    )
 
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        inner_loss = model.get_inner_loss()
+    log_string(f"memex_proba {memex_proba}")
+
+    train_batches = add_memex_v2(
+        batches=task.batches(split="train"),
+        memex_proba=memex_proba,
+        marker_token=vocabulary_size - 1,
+    )
 
-        acc_train_loss += loss.item() * input.size(0)
-        acc_train_inner_loss += inner_loss.item() * input.size(0)
+    def add_none(it):
+        for x in it:
+            yield x
+        yield None
 
-        nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)
+    nb_acc_samples = 0
 
-        total_loss = loss + (
-            args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
-        )
+    for input in add_none(train_batches):
+        if input is not None:
+            model.reset_inner_loss()
+            input = input.to(device)
+
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            inner_loss = model.get_inner_loss()
+
+            acc_train_loss += loss.item() * input.size(0)
+            acc_train_inner_loss += inner_loss.item() * input.size(0)
+
+            nb_train_samples += input.size(0)
+            nb_samples_seen += input.size(0)
 
-        it += 1
-        lr = get_lr(n_epoch, it)
-        for param_group in optimizer.param_groups:
-            param_group["lr"] = lr
+            total_loss = loss + (
+                args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+            )
 
-        # log_string(f"learning_rate {lr}")
+            it += 1
+            lr = get_lr(n_epoch, it)
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = lr
 
-        optimizer.zero_grad()
-        total_loss.backward()
-        optimizer.step()
+            # log_string(f"learning_rate {lr}")
 
-        grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+            total_loss.backward()
+            nb_acc_samples += input.size(0)
 
-        loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+        if (input is None and nb_acc_samples > 0) or nb_acc_samples == args.batch_size:
+            assert nb_acc_samples <= args.batch_size
+            optimizer.step()
+            grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+            loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+            optimizer.zero_grad()
+            nb_acc_samples = 0
 
         n_batch += 1
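
A note on add_memex_v2 above: with probability memex_proba it yields, in front of a training batch, an auxiliary batch of length 2L+1 (L being the sequence length) laid out as [prefix input[:, :u0] | markers | full copy of the sequence starting at u1 | markers]; the marker is the extra vocabulary entry reserved by the `vocabulary_size += 1` hunk. A tiny runnable demonstration of the index arithmetic, with fixed cut points u0=2, u1=5 standing in for the random draws:

    import torch

    # One row, L = 6; the row's tokens are 10..15 and 99 plays the marker.
    L, marker = 6, 99
    input = torch.arange(10, 10 + L)[None, :]
    t = torch.arange(1 + 2 * L)[None, :].clone()       # output positions 0..2L
    u0, u1 = torch.tensor([[2]]), torch.tensor([[5]])  # fixed cut points

    m0 = (t < u0).long()                         # keep the original prefix
    m1 = (t >= u1).long() * (t < u1 + L).long()  # full copy starting at u1
    t = t * m0 + ((-1) * (1 - m0) * (1 - m1)) + (t - u1) * m1
    m = (t < 0).long()                           # -1 slots become markers
    n = torch.zeros_like(t)                      # row indices (batch of 1)
    new = input[n, t.clamp(min=0)]
    print((1 - m) * new + m * marker)
    # tensor([[10, 11, 99, 99, 99, 10, 11, 12, 13, 14, 15, 99, 99]])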
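A note on the batch_size / physical_batch_size split introduced above: the data pipeline now yields micro-batches of physical_batch_size samples, backward() runs on each of them so gradients accumulate in p.grad, and optimizer.step() fires once nb_acc_samples reaches the logical batch_size (or at the end-of-epoch None sentinel). A minimal self-contained sketch of the same gradient-accumulation scheme, with a hypothetical toy model and random data standing in for this repository's task machinery:

    import torch
    from torch import nn
    from torch.nn import functional as F

    batch_size = 25          # logical batch: one optimizer step
    physical_batch_size = 5  # what fits in memory per forward/backward

    model = nn.Linear(16, 4)  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    def micro_batches(nb):
        # hypothetical stand-in for task.batches(split="train")
        for _ in range(nb):
            yield torch.randn(physical_batch_size, 16), torch.randint(
                4, (physical_batch_size,)
            )

    nb_acc_samples = 0
    optimizer.zero_grad()

    for x, y in micro_batches(20):
        loss = F.cross_entropy(model(x), y)
        loss.backward()  # sums into p.grad, does not overwrite
        nb_acc_samples += x.size(0)
        if nb_acc_samples == batch_size:
            optimizer.step()  # one step per batch_size samples
            optimizer.zero_grad()
            nb_acc_samples = 0

Since F.cross_entropy averages within each micro-batch, the accumulated gradient is a sum of micro-batch means (five of them here) rather than the mean over the logical batch; the patch follows the same convention, so that constant factor is absorbed into the learning rate.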