From 798d9526e726b644979cf1124e714f705fdd5966 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 13 Jun 2024 19:40:17 +0200
Subject: [PATCH] Update.

---
 main.py     | 134 ++++++++++++++++++++++++++++++++++++----------------
 problems.py |   8 ++--
 tasks.py    |  34 ++++++++++++-
 3 files changed, 129 insertions(+), 47 deletions(-)

diff --git a/main.py b/main.py
index 9437136..dace5f2 100755
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@
 
 # Written by Francois Fleuret
 
-import math, sys, argparse, time, tqdm, os, datetime
+import math, sys, argparse, time, tqdm, os, datetime, warnings
 
 import torch, torchvision
 from torch import nn
@@ -46,10 +46,12 @@ parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
 ########################################
 
-parser.add_argument("--nb_epochs", type=int, default=25)
+parser.add_argument("--nb_epochs", type=int, default=50)
 
 parser.add_argument("--batch_size", type=int, default=None)
 
+parser.add_argument("--physical_batch_size", type=int, default=None)
+
 parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
@@ -82,7 +84,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--resume", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
@@ -144,6 +146,11 @@ parser.add_argument("--snake_nb_colors", type=int, default=5)
 
 parser.add_argument("--snake_length", type=int, default=200)
 
+##############################
+# ByHeart options
+
+parser.add_argument("--byheart_separation", type=int, default=1)
+
 ##############################
 # Stack options
 
@@ -153,7 +160,7 @@ parser.add_argument("--stack_nb_stacks", type=int, default=3)
 
 parser.add_argument("--stack_nb_digits", type=int, default=3)
 
-parser.add_argument("--stack_fraction_values_for_train", type=float, default=0.75)
+parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
 
 ##############################
 # Expr options
@@ -367,7 +374,7 @@ else:
     try:
         os.mkdir(args.result_dir)
     except FileExistsError:
-        if not args.overwrite_results:
+        if not args.resume:
             print(f"result directory {args.result_dir} already exists")
             exit(1)
 
@@ -422,6 +429,14 @@ picoclvr_pruner_eval = (
 
 ######################################################################
 
+if args.physical_batch_size is None:
+    args.physical_batch_size = args.batch_size
+else:
+    assert args.batch_size % args.physical_batch_size == 0
+
+assert args.nb_train_samples % args.batch_size == 0
+assert args.nb_test_samples % args.batch_size == 0
+
 if args.task == "file":
     assert (
         args.filetask_train_file is not None and args.filetask_test_file is not None
@@ -431,7 +446,7 @@ if args.task == "file":
         args.filetask_test_file,
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         shuffle=True,
         device=device,
     )
@@ -439,10 +454,10 @@
 
 elif args.task == "byheart":
     task = tasks.SandBox(
-        problem=problems.ProblemByHeart(),
+        problem=problems.ProblemByHeart(separation=args.byheart_separation),
         nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -453,7 +468,7 @@ elif args.task == "learnop":
         problem=problems.ProblemLearnOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -464,7 +479,7 @@ elif args.task == "guessop":
         problem=problems.ProblemGuessOperator(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -475,7 +490,7 @@ elif args.task == "twotargets":
         problem=problems.ProblemTwoTargets(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -485,7 +500,7 @@ elif args.task == "memory":
         problem=problems.ProblemMemory(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -497,7 +512,7 @@ elif args.task == "mixing":
         ),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
     )
@@ -507,7 +522,7 @@ elif args.task == "addition":
         problem=problems.ProblemAddition(),
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         device=device,
    )
@@ -516,7 +531,7 @@ elif args.task == "picoclvr":
     task = tasks.PicoCLVR(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.picoclvr_height,
         width=args.picoclvr_width,
         nb_colors=args.picoclvr_nb_colors,
@@ -530,7 +545,7 @@ elif args.task == "mnist":
     task = tasks.MNIST(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device,
     )
 
@@ -538,18 +553,18 @@ elif args.task == "maze":
     task = tasks.Maze(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.maze_height,
         width=args.maze_width,
         nb_walls=args.maze_nb_walls,
-        device=device,
+        device="cpu",
     )
 
 elif args.task == "snake":
     task = tasks.Snake(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.snake_height,
         width=args.snake_width,
         nb_colors=args.snake_nb_colors,
@@ -562,7 +577,7 @@ elif args.task == "stack":
     task = tasks.Stack(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         logger=log_string,
         nb_steps=args.stack_nb_steps,
         nb_stacks=args.stack_nb_stacks,
@@ -579,7 +594,7 @@ elif args.task == "expr":
         sequence_length=args.expr_sequence_length,
         operand_max=args.expr_operand_max,
         result_max=args.expr_result_max,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         device=device,
     )
 
@@ -587,7 +602,7 @@ elif args.task == "rpl":
     task = tasks.RPL(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         nb_starting_values=args.rpl_nb_starting_values,
         max_input=args.rpl_max_input,
         prog_len=args.rpl_prog_len,
@@ -601,7 +616,7 @@ elif args.task == "grid":
     task = tasks.Grid(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         size=args.grid_size,
         fraction_play=args.grid_fraction_play,
         logger=log_string,
@@ -612,7 +627,7 @@ elif args.task == "qmlp":
     task = tasks.QMLP(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         result_dir=args.result_dir,
         logger=log_string,
         device=device,
@@ -622,7 +637,7 @@ elif args.task == "greed":
     task = tasks.Greed(
         nb_train_samples=args.nb_train_samples,
         nb_test_samples=args.nb_test_samples,
-        batch_size=args.batch_size,
+        batch_size=args.physical_batch_size,
         height=args.greed_height,
         width=args.greed_width,
         T=args.greed_T,
@@ -773,8 +788,6 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 ##############################
 
-nb_samples_seen = 0
-
 if nb_epochs_finished >= args.nb_epochs:
     task.produce_results(
         n_epoch=nb_epochs_finished,
@@ -802,40 +815,77 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     model.train()
 
-    nb_train_samples, acc_train_loss = 0, 0.0
+    nb_train_samples, acc_train_loss_ar, acc_train_loss_ae = 0, 0.0, 0.0
 
     for input in task.batches(split="train"):
         input = input.to(device)
-        output = model(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), input)
-        acc_train_loss += loss.item() * input.size(0)
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.zero_grad()
+
+        if args.autoencoder_weight > 0:
+            bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
+            output_ar, output_ae = bs_ar.x, bs_ae.x
+            loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
+            loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1])
+        else:
+            output = model(mygpt.BracketedSequence(input)).x
+            loss_ar = F.cross_entropy(output.transpose(1, 2), input)
+            loss_ae = loss_ar.new_full((1,), 0.0)
+
+        acc_train_loss_ar += loss_ar.item() * input.size(0)
+        acc_train_loss_ae += loss_ae.item() * input.size(0)
+
         nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)
 
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        (loss_ar + args.autoencoder_weight * loss_ae).backward()
+
+        if nb_train_samples % args.batch_size == 0:
+            optimizer.step()
 
     with torch.autograd.no_grad():
         model.eval()
 
-        nb_test_samples, acc_test_loss = 0, 0.0
+        nb_test_samples, acc_test_loss_ar, acc_test_loss_ae = 0, 0.0, 0.0
+        nb_samples_accumulated = 0
 
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_test_loss += loss.item() * input.size(0)
+            if args.autoencoder_weight > 0:
+                bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
+                output_ar, output_ae = bs_ar.x, bs_ae.x
+                loss_ae = F.cross_entropy(
+                    output_ae[:, 1:].transpose(1, 2), input[:, :-1]
+                )
+                acc_test_loss_ae += loss_ae.item() * input.size(0)
+            else:
+                bs_ar = model(mygpt.BracketedSequence(input))
+                output_ar = bs_ar.x
+
+            loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
+
+            acc_test_loss_ar += loss_ar.item() * input.size(0)
+
             nb_test_samples += input.size(0)
 
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+        train_ar_perplexity = math.exp(min(100, acc_train_loss_ar / nb_train_samples))
+        test_ar_perplexity = math.exp(min(100, acc_test_loss_ar / nb_test_samples))
 
         log_string(
-            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+            f"perplexity_ar {n_epoch} train_set {train_set_perplexity} train_prediction {train_ar_perplexity} test_prediction {test_ar_perplexity}"
         )
 
+        if args.autoencoder_weight > 0:
+            train_ae_perplexity = math.exp(
+                min(100, acc_train_loss_ae / nb_train_samples)
+            )
+            test_ae_perplexity = math.exp(min(100, acc_test_loss_ae / nb_test_samples))
+
+            log_string(
+                f"perplexity_ae {n_epoch} train_set {train_set_perplexity} train_prediction {train_ae_perplexity} test_prediction {test_ae_perplexity}"
+            )
+
         task.produce_results(
             n_epoch=n_epoch,
             model=model,
diff --git a/problems.py b/problems.py
index d7dbc54..446e1a1 100755
--- a/problems.py
+++ b/problems.py
@@ -200,9 +200,11 @@ class ProblemTwoTargets(Problem):
 
 
 class ProblemByHeart(Problem):
-    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
-        self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
-        self.seq[:, len_prompt] = 10
+    def __init__(self, nb_sentences=100, len_prompt=8, len_result=8, separation=1):
+        self.seq = torch.randint(
+            10, (nb_sentences, len_prompt + separation + len_result)
+        )
+        self.seq[:, len_prompt : len_prompt + separation] = 10
 
     def generate_sequences(self, nb):
         sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
diff --git a/tasks.py b/tasks.py
index c0ad5ff..443419e 100755
--- a/tasks.py
+++ b/tasks.py
@@ -754,15 +754,17 @@ class Maze(Task):
     def compute_error(
         self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
     ):
+        model_device = next(model.parameters()).device
         nb_total, nb_correct = 0, 0
         count = torch.zeros(
             self.width * self.height,
             self.width * self.height,
-            device=self.device,
+            device=model_device,
             dtype=torch.int64,
         )
 
         for input in self.batches(split, nb_to_use):
+            input = input.to(model_device)
             result = input.clone()
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
@@ -836,7 +838,7 @@ class Maze(Task):
                 eol = " " if j < count.size(1) - 1 else "\n"
                 f.write(f"{count[i,j]}{eol}")
 
-        input = self.test_input[:48]
+        input = self.test_input[:48].to(next(model.parameters()).device)
         result = input.clone()
         ar_mask = result.new_zeros(result.size())
         ar_mask[:, self.height * self.width :] = 1
@@ -1098,6 +1100,34 @@ class Stack(Task):
             device=self.device,
         )
 
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        for label, input in [
+            ("train", self.train_input[:32]),
+            ("test", self.test_input[:32]),
+        ]:
+            output = model(BracketedSequence(input)).x
+            output = output.log_softmax(dim=-1)
+            filename = os.path.join(
+                result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
+            )
+            with open(filename, "w") as f:
+                for n in range(input.size(0)):
+                    s = stack.seq_to_str(
+                        input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
+                    )
+                    for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
+                        u = (
+                            " " * (10 - len(w))
+                            + w
+                            + " "
+                            + str(output[n][t][k].exp().item())
+                            + "\n"
+                        )
+                        f.write(u)
+                    f.write("\n")
+            logger(f"wrote {filename}")
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
         for n in range(result.size(0)):
             logger(
                 f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-- 
2.39.5
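
Note, not part of the patch: the central change in main.py above is gradient
accumulation. --physical_batch_size is the micro-batch that goes through one
forward/backward pass, while --batch_size remains the logical batch: gradients
are summed across micro-batches and the optimizer steps once per logical
batch, hence the new asserts that batch_size is a multiple of
physical_batch_size and divides the sample counts. A minimal self-contained
sketch of the scheme (toy model, data, and constants are placeholders, not the
patch's code):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(16, 4)  # stand-in for the GPT
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    batch_size = 64  # logical batch: one optimizer step
    physical_batch_size = 16  # micro-batch: one forward/backward pass
    assert batch_size % physical_batch_size == 0

    nb_train_samples = 0
    for _ in range(8):  # stand-in for task.batches(split="train")
        input = torch.randn(physical_batch_size, 16)
        target = torch.randint(4, (physical_batch_size,))

        if nb_train_samples % batch_size == 0:
            optimizer.zero_grad()  # a new logical batch starts

        loss = F.cross_entropy(model(input), target)
        nb_train_samples += input.size(0)
        loss.backward()  # gradients accumulate across micro-batches

        if nb_train_samples % batch_size == 0:
            optimizer.step()  # one step per logical batch

As in the patch, each micro-batch loss is a mean over its own samples, so the
accumulated gradient is a sum of micro-batch means; with equal micro-batch
sizes this is proportional to the logical-batch mean gradient, a constant
factor that Adam largely absorbs.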