From 4ec52fe66419a6e1d2b231108ccbb45902395fcc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jun 2024 17:35:31 +0200 Subject: [PATCH] Update. --- main.py | 264 ++++++------------------------------------------------- tasks.py | 20 ++++- world.py | 45 +++++++--- 3 files changed, 77 insertions(+), 252 deletions(-) diff --git a/main.py b/main.py index 549e7ea..09ae823 100755 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ from torch import nn from torch.nn import functional as F import ffutils -import mygpt, tasks, problems +import mygpt, tasks ###################################################################### @@ -29,8 +29,6 @@ parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) -parser.add_argument("--task", type=str, default="world", help="world") - parser.add_argument("--log_filename", type=str, default="train.log", help=" ") parser.add_argument("--result_dir", type=str, default=None) @@ -82,23 +80,20 @@ parser.add_argument("--check", action="store_true", default=False) args = parser.parse_args() if args.result_dir is None: - args.result_dir = f"results_{args.task}" + args.result_dir = f"results_culture" ###################################################################### -default_task_args = { - "world": { - "model": "37M", - "batch_size": 100, - "nb_train_samples": 250000, - "nb_test_samples": 10000, - }, +default_args = { + "model": "37M", + "batch_size": 100, + "nb_train_samples": 250000, + "nb_test_samples": 10000, } -if args.task in default_task_args: - for k, v in default_task_args[args.task].items(): - if getattr(args, k) is None: - setattr(args, k, v) +for k, v in default_args.items(): + if getattr(args, k) is None: + setattr(args, k, v) ###################################################################### @@ -199,229 +194,14 @@ else: assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 -if args.task == "file": - assert ( - args.filetask_train_file is not None and args.filetask_test_file is not None - ), "You have to specify the task train and test files" - task = tasks.TaskFromFile( - args.filetask_train_file, - args.filetask_test_file, - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - shuffle=True, - device=device, - ) - args.max_percents_of_test_in_train = 0 - -elif args.task == "byheart": - task = tasks.SandBox( - problem=problems.ProblemByHeart(separation=args.byheart_separation), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - args.max_percents_of_test_in_train = -1 - -elif args.task == "world": - task = tasks.World( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - result_dir=args.result_dir, - logger=log_string, - device=device, - ) - args.max_percents_of_test_in_train = -1 - -elif args.task == "learnop": - task = tasks.SandBox( - problem=problems.ProblemLearnOperator(), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - - -elif args.task == "guessop": - task = tasks.SandBox( - problem=problems.ProblemGuessOperator(), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - - -elif args.task == "twotargets": - task = tasks.SandBox( - problem=problems.ProblemTwoTargets(), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - -elif args.task == "memory": - task = tasks.SandBox( - problem=problems.ProblemMemory(), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - -elif args.task == "mixing": - task = tasks.SandBox( - problem=problems.ProblemMixing( - hard=args.mixing_hard, random_start=not args.mixing_deterministic_start - ), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - -elif args.task == "addition": - task = tasks.SandBox( - problem=problems.ProblemAddition(), - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - device=device, - ) - -elif args.task == "picoclvr": - task = tasks.PicoCLVR( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - height=args.picoclvr_height, - width=args.picoclvr_width, - nb_colors=args.picoclvr_nb_colors, - logger=log_string, - device=device, - pruner_train=picoclvr_pruner_train, - pruner_eval=picoclvr_pruner_eval, - ) - -elif args.task == "mnist": - task = tasks.MNIST( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - device=device, - ) - -elif args.task == "maze": - task = tasks.Maze( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - height=args.maze_height, - width=args.maze_width, - nb_walls=args.maze_nb_walls, - device="cpu", - ) - -elif args.task == "snake": - task = tasks.Snake( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - height=args.snake_height, - width=args.snake_width, - nb_colors=args.snake_nb_colors, - length=args.snake_length, - prompt_length=args.snake_length // 2, - device=device, - ) - -elif args.task == "stack": - task = tasks.Stack( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - logger=log_string, - nb_steps=args.stack_nb_steps, - nb_stacks=args.stack_nb_stacks, - nb_digits=args.stack_nb_digits, - fraction_values_for_train=args.stack_fraction_values_for_train, - device=device, - ) - -elif args.task == "expr": - task = tasks.Expr( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - nb_variables=args.expr_nb_variables, - sequence_length=args.expr_sequence_length, - operand_max=args.expr_operand_max, - result_max=args.expr_result_max, - batch_size=args.physical_batch_size, - device=device, - ) - -elif args.task == "rpl": - task = tasks.RPL( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - nb_starting_values=args.rpl_nb_starting_values, - max_input=args.rpl_max_input, - prog_len=args.rpl_prog_len, - nb_runs=args.rpl_nb_runs, - no_prog=args.rpl_no_prog, - logger=log_string, - device=device, - ) - -elif args.task == "grid": - task = tasks.Grid( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - size=args.grid_size, - fraction_play=args.grid_fraction_play, - logger=log_string, - device=device, - ) - -elif args.task == "qmlp": - task = tasks.QMLP( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - result_dir=args.result_dir, - logger=log_string, - device=device, - ) - -elif args.task == "greed": - task = tasks.Greed( - nb_train_samples=args.nb_train_samples, - nb_test_samples=args.nb_test_samples, - batch_size=args.physical_batch_size, - height=args.greed_height, - width=args.greed_width, - T=args.greed_T, - nb_walls=args.greed_nb_walls, - nb_coins=args.greed_nb_coins, - logger=log_string, - device=device, - ) - -else: - raise ValueError(f"Unknown task {args.task}") +task = tasks.World( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.physical_batch_size, + result_dir=args.result_dir, + logger=log_string, + device=device, +) ###################################################################### @@ -624,6 +404,10 @@ if args.check: nb_new_quizzes_for_test = 10 for n_epoch in range(args.nb_epochs): + a = [(model.id, model.main_test_accuracy) for model in models] + a.sort(key=lambda p: p[0]) + log_string(f"current accuracies {a}") + # select the model with lowest accuracy models.sort(key=lambda model: model.main_test_accuracy) model = models[0] @@ -654,5 +438,9 @@ for n_epoch in range(args.nb_epochs): nb_for_test=nb_new_quizzes_for_test, ) + # We update everyone + for model in models: + run_tests(model, task, deterministic_synthesis=False) + ###################################################################### diff --git a/tasks.py b/tasks.py index cb5900b..9a67127 100755 --- a/tasks.py +++ b/tasks.py @@ -112,6 +112,13 @@ class World(Task): nb_test_samples, height=self.height, width=self.width ).to(device) + # print() + # for a in world.seq2str(self.train_input): + # print(a) + # for a in world.seq2str(self.test_input): + # print(a) + # exit(0) + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 self.train_quizzes = [] @@ -274,7 +281,7 @@ class World(Task): # Check how many of the other models can solve them in both # directions - nb_correct = 0 + nb_correct = [] for m in other_models: result = quizzes.clone() @@ -307,6 +314,13 @@ class World(Task): (reverse_quizzes == reverse_result).long().min(dim=-1).values ) - nb_correct += correct * reverse_correct + nb_correct.append((correct * reverse_correct)[None, :]) + + nb_correct = torch.cat(nb_correct, dim=0) + + filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") + with open(filename, "w") as f: + for k in nb_correct: + f.write(f"{k}\n") - return quizzes, nb_correct + return quizzes, nb_correct.sum(dim=0) diff --git a/world.py b/world.py index ab02c82..4055533 100755 --- a/world.py +++ b/world.py @@ -18,28 +18,48 @@ from torch.nn import functional as F colors = torch.tensor( [ [255, 255, 255], - [255, 0, 0], - [0, 128, 0], [0, 0, 255], - [255, 200, 0], + [0, 0, 255], + [0, 192, 0], + [0, 255, 0], + [0, 255, 127], + [0, 255, 255], + [0, 255, 255], + [30, 144, 255], + [64, 224, 208], + [65, 105, 225], + [75, 0, 130], + [106, 90, 205], + [128, 0, 128], + [135, 206, 235], [192, 192, 192], + [220, 20, 60], + [250, 128, 114], + [255, 0, 0], + [255, 0, 255], + [255, 105, 180], + [255, 127, 80], + [255, 165, 0], + [255, 182, 193], + [255, 20, 147], + [255, 200, 0], ] ) token_background = 0 first_bird_token = 1 -nb_bird_tokens = len(colors) - 1 +nb_bird_tokens = colors.size(0) - 1 token_forward = first_bird_token + nb_bird_tokens token_backward = token_forward + 1 -token2char = "_" + "".join([str(n) for n in range(len(colors) - 1)]) + "><" +token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" def generate( nb, height, width, - max_nb_obj=2, + nb_birds=2, nb_iterations=2, ): pairs = [] @@ -49,7 +69,6 @@ def generate( f_end = torch.zeros(height, width, dtype=torch.int64) n = torch.arange(f_start.size(0)) - nb_birds = torch.randint(max_nb_obj, (1,)).item() + 1 for c in ( (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values ): @@ -115,6 +134,10 @@ def sample2img(seq, height, width, upscale=15): x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale) x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale) + x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0 + x[:, :, torch.arange(0, x.size(2), upscale), :] = 0 + x = x[:, :, 1:, 1:] + for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): @@ -125,9 +148,9 @@ def sample2img(seq, height, width, upscale=15): return x - direction_symbol = torch.full((direction.size(0), height * upscale, upscale), 0) + direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0) direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2) - separator = torch.full((direction.size(0), 3, height * upscale, 1), 0) + separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0) for n in range(direction_symbol.size(0)): if direction[n] == token_forward: @@ -181,7 +204,7 @@ if __name__ == "__main__": height, width = 6, 8 start_time = time.perf_counter() - seq = generate(nb=90, height=height, width=width, max_nb_obj=3) + seq = generate(nb=90, height=height, width=width) delay = time.perf_counter() - start_time print(f"{seq.size(0)/delay:02f} samples/s") @@ -194,5 +217,5 @@ if __name__ == "__main__": print(img.size()) torchvision.utils.save_image( - img.float() / 255.0, "/tmp/world.png", nrow=6, padding=4 + img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0 ) -- 2.39.5