X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;fp=main.py;h=dace5f2e2276ff4ba2472b90677b253dea58a46e;hb=798d9526e726b644979cf1124e714f705fdd5966;hp=9437136ce1a45b066d6884e205540083bfb4d2d6;hpb=3528c66810984055a0e0f0cf7a4169c3340be0c8;p=picoclvr.git

diff --git a/main.py b/main.py
index 9437136..dace5f2 100755
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@
 # Written by Francois Fleuret

-import math, sys, argparse, time, tqdm, os, datetime
+import math, sys, argparse, time, tqdm, os, datetime, warnings

 import torch, torchvision

 from torch import nn
@@ -46,10 +46,12 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)

 ########################################

-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=50)

 parser.add_argument("--batch_size", type=int, default=None)

+parser.add_argument("--physical_batch_size", type=int, default=None)
+
 parser.add_argument("--nb_train_samples", type=int, default=None)

 parser.add_argument("--nb_test_samples", type=int, default=None)
@@ -82,7 +84,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa

 parser.add_argument("--no_checkpoint", action="store_true", default=False)

-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--resume", action="store_true", default=False)

 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

@@ -144,6 +146,11 @@ parser.add_argument("--snake_nb_colors", type=int, default=5)

 parser.add_argument("--snake_length", type=int, default=200)

+##############################
+# ByHeart options
+
+parser.add_argument("--byheart_separation", type=int, default=1)
+
 ##############################
 # Stack options

@@ -153,7 +160,7 @@ parser.add_argument("--stack_nb_stacks", type=int, default=3)

 parser.add_argument("--stack_nb_digits", type=int, default=3)

-parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)

 ##############################
 # Expr options
@@ -367,7 +374,7 @@ else:
     try:
         os.mkdir(args.result_dir)
     except FileExistsError:
-        if not args.overwrite_results:
+        if not args.resume:
            print(f"result directory {args.result_dir} already exists")
            exit(1)

@@ -422,6 +429,14 @@ picoclvr_pruner_eval = (

 ######################################################################

+if args.physical_batch_size is None:
+    args.physical_batch_size = args.batch_size
+else:
+    assert args.batch_size % args.physical_batch_size == 0
+
+assert args.nb_train_samples % args.batch_size == 0
+assert args.nb_test_samples % args.batch_size == 0
+
 if args.task == "file":
     assert (
         args.filetask_train_file is not None and args.filetask_test_file is not None
@@ -431,7 +446,7 @@ if args.task == "file":
         args.filetask_test_file,
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         shuffle=True,
         device=device,
     )
@@ -439,10 +454,10 @@ if args.task == "file":

 elif args.task == "byheart":
     task = tasks.SandBox(
-        problem=problems.ProblemByHeart(),
+        problem=problems.ProblemByHeart(separation=args.byheart_separation),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -453,7 +468,7 @@ elif args.task == "learnop":
         problem=problems.ProblemLearnOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -464,7 +479,7 @@ elif args.task == "guessop":
         problem=problems.ProblemGuessOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -475,7 +490,7 @@ elif args.task == "twotargets":
         problem=problems.ProblemTwoTargets(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -485,7 +500,7 @@ elif args.task == "memory":
         problem=problems.ProblemMemory(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -497,7 +512,7 @@ elif args.task == "mixing":
         ),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -507,7 +522,7 @@ elif args.task == "addition":
         problem=problems.ProblemAddition(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -516,7 +531,7 @@ elif args.task == "picoclvr":
     task = tasks.PicoCLVR(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.picoclvr_height,
         width=args.picoclvr_width,
         nb_colors=args.picoclvr_nb_colors,
@@ -530,7 +545,7 @@ elif args.task == "mnist":
     task = tasks.MNIST(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device,
     )

@@ -538,18 +553,18 @@ elif args.task == "maze":
     task = tasks.Maze(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.maze_height,
         width=args.maze_width,
         nb_walls=args.maze_nb_walls,
-        device=device,
+        device="cpu",
     )

 elif args.task == "snake":
     task = tasks.Snake(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.snake_height,
         width=args.snake_width,
         nb_colors=args.snake_nb_colors,
@@ -562,7 +577,7 @@ elif args.task == "stack":
     task = tasks.Stack(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
@@ -579,7 +594,7 @@ elif args.task == "expr":
         sequence_length=args.expr_sequence_length,
         operand_max=args.expr_operand_max,
         result_max=args.expr_result_max,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device,
     )

@@ -587,7 +602,7 @@ elif args.task == "rpl":
     task = tasks.RPL(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         nb_starting_values=args.rpl_nb_starting_values,
         max_input=args.rpl_max_input,
         prog_len=args.rpl_prog_len,
@@ -601,7 +616,7 @@ elif args.task == "grid":
     task = tasks.Grid(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         size=args.grid_size,
         fraction_play=args.grid_fraction_play,
         logger=log_string,
@@ -612,7 +627,7 @@ elif args.task == "qmlp":
     task = tasks.QMLP(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         result_dir=args.result_dir,
         logger=log_string,
         device=device,
@@ -622,7 +637,7 @@ elif args.task == "greed":
     task = tasks.Greed(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.greed_height,
         width=args.greed_width,
         T=args.greed_T,
@@ -773,8 +788,6 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")

 ##############################

-nb_samples_seen = 0
-
 if nb_epochs_finished >= args.nb_epochs:
     task.produce_results(
         n_epoch=nb_epochs_finished,
@@ -802,40 +815,77 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):

     model.train()

-    nb_train_samples, acc_train_loss = 0, 0.0
+    nb_train_samples, acc_train_loss_ar, acc_train_loss_ae = 0, 0.0, 0.0

     for input in task.batches(split="train"):
         input = input.to(device)
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        acc_train_loss += loss.item() * input.size(0)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        if args.autoencoder_weight > 0:
+            bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
+            output_ar, output_ae = bs_ar.x, bs_ae.x
+            loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
+            loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1])
+        else:
+            output = model(mygpt.BracketedSequence(input)).x
+            loss_ar = F.cross_entropy(output.transpose(1, 2), input)
+            loss_ae = loss_ar.new_full((1,), 0.0)
+
+        acc_train_loss_ar += loss_ar.item() * input.size(0)
+        acc_train_loss_ae += loss_ae.item() * input.size(0)
+
         nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)

-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        (loss_ar + args.autoencoder_weight * loss_ae).backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()

     with torch.autograd.no_grad():
         model.eval()

-        nb_test_samples, acc_test_loss = 0, 0.0
+        nb_test_samples, acc_test_loss_ar, acc_test_loss_ae = 0, 0.0, 0.0
+        nb_samples_accumulated = 0

         for input in task.batches(split="test"):
             input = input.to(device)

-            output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_test_loss += loss.item() * input.size(0)
+            if args.autoencoder_weight > 0:
+                bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
+                output_ar, output_ae = bs_ar.x, bs_ae.x
+                loss_ae = F.cross_entropy(
+                    output_ae[:, 1:].transpose(1, 2), input[:, :-1]
+                )
+                acc_test_loss_ae += loss_ae.item() * input.size(0)
+            else:
+                bs_ar = model(mygpt.BracketedSequence(input))
+                output_ar = bs_ar.x
+
+            loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
+
+            acc_test_loss_ar += loss_ar.item() * input.size(0)
+
             nb_test_samples += input.size(0)

-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+        train_ar_perplexity = math.exp(min(100, acc_train_loss_ar / nb_train_samples))
+        test_ar_perplexity = math.exp(min(100, acc_test_loss_ar / nb_test_samples))

         log_string(
-            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+            f"perplexity_ar {n_epoch} train_set {train_set_perplexity} train_prediction {train_ar_perplexity} test_prediction {test_ar_perplexity}"
         )

+        if args.autoencoder_weight > 0:
+            train_ae_perplexity = math.exp(
+                min(100, acc_train_loss_ae / nb_train_samples)
+            )
+            test_ae_perplexity = math.exp(min(100, acc_test_loss_ae / nb_test_samples))
+
+            log_string(
+                f"perplexity_ae {n_epoch} train_set {train_set_perplexity} train_prediction {train_ae_perplexity} test_prediction {test_ae_perplexity}"
+            )
+
         task.produce_results(
             n_epoch=n_epoch,
             model=model,
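
The physical_batch_size change above decouples the optimizer step from the
forward/backward pass: gradients are accumulated over several physical
batches and applied once per logical batch of args.batch_size samples.
Below is a minimal self-contained sketch of the same accumulation pattern,
with a hypothetical toy model and random data standing in for main.py's
GPT and task batches:

    import torch
    from torch import nn

    # Hypothetical stand-ins for the model, optimizer and batches of main.py.
    model = nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    batch_size = 64           # one optimizer step (args.batch_size)
    physical_batch_size = 16  # one forward/backward pass (args.physical_batch_size)
    assert batch_size % physical_batch_size == 0

    nb_train_samples = 0
    for _ in range(8):  # stands in for task.batches(split="train")
        input = torch.randn(physical_batch_size, 16)

        # Zero the gradients only at the start of a logical batch...
        if nb_train_samples % batch_size == 0:
            optimizer.zero_grad()

        loss = (model(input) - input).pow(2).mean()
        nb_train_samples += input.size(0)

        # ...accumulate them across the physical batches...
        loss.backward()

        # ...and step once a full logical batch has been seen. This is why
        # the diff asserts nb_train_samples % batch_size == 0: a trailing
        # partial batch would never reach optimizer.step().
        if nb_train_samples % batch_size == 0:
            optimizer.step()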
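The two cross-entropies in the new training loop differ only by an index
shift. The autoregressive loss scores output position t against input token
t (main.py relies on the model shifting its input internally), while the
autoencoder head's output at position t+1 is scored against input token t,
hence the [:, 1:] / [:, :-1] slicing. A shape-level sketch, with random
tensors in place of the model outputs:

    import torch
    import torch.nn.functional as F

    N, T, C = 2, 5, 10                # batch, sequence length, vocabulary size
    output = torch.randn(N, T, C)     # same shape as model(...).x
    input = torch.randint(C, (N, T))  # token sequence

    # F.cross_entropy wants the class dimension second, hence transpose(1, 2):
    # logits (N, C, T) against targets (N, T).
    loss_ar = F.cross_entropy(output.transpose(1, 2), input)

    # Autoencoder variant: drop the first output position and the last target
    # token so that output position t+1 is compared to input token t.
    loss_ae = F.cross_entropy(output[:, 1:].transpose(1, 2), input[:, :-1])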
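Finally, the per-epoch perplexities are the exponentials of the averaged
cross-entropy losses, with the exponent clamped so that a diverging loss
early in training cannot overflow a float:

    import math

    # Hypothetical accumulated values, in the role of acc_train_loss_ar and
    # nb_train_samples at the end of an epoch.
    acc_loss, nb_samples = 250.0, 100

    # min(100, .) caps the exponent; math.exp(1000) would raise OverflowError.
    perplexity = math.exp(min(100, acc_loss / nb_samples))
    print(f"perplexity {perplexity:.2f}")  # exp(2.5) ~ 12.18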