Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 15:35:31 +0000 (17:35 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 15:35:31 +0000 (17:35 +0200)
main.py
tasks.py
world.py

diff --git a/main.py b/main.py
index 549e7ea..09ae823 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,7 +12,7 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
-import mygpt, tasks, problems
+import mygpt, tasks
 
 ######################################################################
 
@@ -29,8 +29,6 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--task", type=str, default="world", help="world")
-
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
 parser.add_argument("--result_dir", type=str, default=None)
@@ -82,23 +80,20 @@ parser.add_argument("--check", action="store_true", default=False)
 args = parser.parse_args()
 
 if args.result_dir is None:
-    args.result_dir = f"results_{args.task}"
+    args.result_dir = f"results_culture"
 
 ######################################################################
 
-default_task_args = {
-    "world": {
-        "model": "37M",
-        "batch_size": 100,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
+default_args = {
+    "model": "37M",
+    "batch_size": 100,
+    "nb_train_samples": 250000,
+    "nb_test_samples": 10000,
 }
 
-if args.task in default_task_args:
-    for k, v in default_task_args[args.task].items():
-        if getattr(args, k) is None:
-            setattr(args, k, v)
+for k, v in default_args.items():
+    if getattr(args, k) is None:
+        setattr(args, k, v)
 
 ######################################################################
 
@@ -199,229 +194,14 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
-if args.task == "file":
-    assert (
-        args.filetask_train_file is not None and args.filetask_test_file is not None
-    ), "You have to specify the task train and test files"
-    task = tasks.TaskFromFile(
-        args.filetask_train_file,
-        args.filetask_test_file,
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        shuffle=True,
-        device=device,
-    )
-    args.max_percents_of_test_in_train = 0
-
-elif args.task == "byheart":
-    task = tasks.SandBox(
-        problem=problems.ProblemByHeart(separation=args.byheart_separation),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-    args.max_percents_of_test_in_train = -1
-
-elif args.task == "world":
-    task = tasks.World(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        result_dir=args.result_dir,
-        logger=log_string,
-        device=device,
-    )
-    args.max_percents_of_test_in_train = -1
-
-elif args.task == "learnop":
-    task = tasks.SandBox(
-        problem=problems.ProblemLearnOperator(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-
-elif args.task == "guessop":
-    task = tasks.SandBox(
-        problem=problems.ProblemGuessOperator(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-
-elif args.task == "twotargets":
-    task = tasks.SandBox(
-        problem=problems.ProblemTwoTargets(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "memory":
-    task = tasks.SandBox(
-        problem=problems.ProblemMemory(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "mixing":
-    task = tasks.SandBox(
-        problem=problems.ProblemMixing(
-            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
-        ),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "addition":
-    task = tasks.SandBox(
-        problem=problems.ProblemAddition(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "picoclvr":
-    task = tasks.PicoCLVR(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.picoclvr_height,
-        width=args.picoclvr_width,
-        nb_colors=args.picoclvr_nb_colors,
-        logger=log_string,
-        device=device,
-        pruner_train=picoclvr_pruner_train,
-        pruner_eval=picoclvr_pruner_eval,
-    )
-
-elif args.task == "mnist":
-    task = tasks.MNIST(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        device=device,
-    )
-
-elif args.task == "maze":
-    task = tasks.Maze(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.maze_height,
-        width=args.maze_width,
-        nb_walls=args.maze_nb_walls,
-        device="cpu",
-    )
-
-elif args.task == "snake":
-    task = tasks.Snake(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.snake_height,
-        width=args.snake_width,
-        nb_colors=args.snake_nb_colors,
-        length=args.snake_length,
-        prompt_length=args.snake_length // 2,
-        device=device,
-    )
-
-elif args.task == "stack":
-    task = tasks.Stack(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        nb_steps=args.stack_nb_steps,
-        nb_stacks=args.stack_nb_stacks,
-        nb_digits=args.stack_nb_digits,
-        fraction_values_for_train=args.stack_fraction_values_for_train,
-        device=device,
-    )
-
-elif args.task == "expr":
-    task = tasks.Expr(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        nb_variables=args.expr_nb_variables,
-        sequence_length=args.expr_sequence_length,
-        operand_max=args.expr_operand_max,
-        result_max=args.expr_result_max,
-        batch_size=args.physical_batch_size,
-        device=device,
-    )
-
-elif args.task == "rpl":
-    task = tasks.RPL(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        nb_starting_values=args.rpl_nb_starting_values,
-        max_input=args.rpl_max_input,
-        prog_len=args.rpl_prog_len,
-        nb_runs=args.rpl_nb_runs,
-        no_prog=args.rpl_no_prog,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "grid":
-    task = tasks.Grid(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        size=args.grid_size,
-        fraction_play=args.grid_fraction_play,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "qmlp":
-    task = tasks.QMLP(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        result_dir=args.result_dir,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "greed":
-    task = tasks.Greed(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.greed_height,
-        width=args.greed_width,
-        T=args.greed_T,
-        nb_walls=args.greed_nb_walls,
-        nb_coins=args.greed_nb_coins,
-        logger=log_string,
-        device=device,
-    )
-
-else:
-    raise ValueError(f"Unknown task {args.task}")
+task = tasks.World(
+    nb_train_samples=args.nb_train_samples,
+    nb_test_samples=args.nb_test_samples,
+    batch_size=args.physical_batch_size,
+    result_dir=args.result_dir,
+    logger=log_string,
+    device=device,
+)
 
 ######################################################################
 
@@ -624,6 +404,10 @@ if args.check:
     nb_new_quizzes_for_test = 10
 
 for n_epoch in range(args.nb_epochs):
+    a = [(model.id, model.main_test_accuracy) for model in models]
+    a.sort(key=lambda p: p[0])
+    log_string(f"current accuracies {a}")
+
     # select the model with lowest accuracy
     models.sort(key=lambda model: model.main_test_accuracy)
     model = models[0]
@@ -654,5 +438,9 @@ for n_epoch in range(args.nb_epochs):
             nb_for_test=nb_new_quizzes_for_test,
         )
 
+        # We update everyone
+        for model in models:
+            run_tests(model, task, deterministic_synthesis=False)
+
 
 ######################################################################
index cb5900b..9a67127 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -112,6 +112,13 @@ class World(Task):
             nb_test_samples, height=self.height, width=self.width
         ).to(device)
 
+        # print()
+        # for a in world.seq2str(self.train_input):
+        # print(a)
+        # for a in world.seq2str(self.test_input):
+        # print(a)
+        # exit(0)
+
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
         self.train_quizzes = []
@@ -274,7 +281,7 @@ class World(Task):
         # Check how many of the other models can solve them in both
         # directions
 
-        nb_correct = 0
+        nb_correct = []
 
         for m in other_models:
             result = quizzes.clone()
@@ -307,6 +314,13 @@ class World(Task):
                 (reverse_quizzes == reverse_result).long().min(dim=-1).values
             )
 
-            nb_correct += correct * reverse_correct
+            nb_correct.append((correct * reverse_correct)[None, :])
+
+        nb_correct = torch.cat(nb_correct, dim=0)
+
+        filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+        with open(filename, "w") as f:
+            for k in nb_correct:
+                f.write(f"{k}\n")
 
-        return quizzes, nb_correct
+        return quizzes, nb_correct.sum(dim=0)
index ab02c82..4055533 100755 (executable)
--- a/world.py
+++ b/world.py
@@ -18,28 +18,48 @@ from torch.nn import functional as F
 colors = torch.tensor(
     [
         [255, 255, 255],
-        [255, 0, 0],
-        [0, 128, 0],
         [0, 0, 255],
-        [255, 200, 0],
+        [0, 0, 255],
+        [0, 192, 0],
+        [0, 255, 0],
+        [0, 255, 127],
+        [0, 255, 255],
+        [0, 255, 255],
+        [30, 144, 255],
+        [64, 224, 208],
+        [65, 105, 225],
+        [75, 0, 130],
+        [106, 90, 205],
+        [128, 0, 128],
+        [135, 206, 235],
         [192, 192, 192],
+        [220, 20, 60],
+        [250, 128, 114],
+        [255, 0, 0],
+        [255, 0, 255],
+        [255, 105, 180],
+        [255, 127, 80],
+        [255, 165, 0],
+        [255, 182, 193],
+        [255, 20, 147],
+        [255, 200, 0],
     ]
 )
 
 token_background = 0
 first_bird_token = 1
-nb_bird_tokens = len(colors) - 1
+nb_bird_tokens = colors.size(0) - 1
 token_forward = first_bird_token + nb_bird_tokens
 token_backward = token_forward + 1
 
-token2char = "_" + "".join([str(n) for n in range(len(colors) - 1)]) + "><"
+token2char = "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><"
 
 
 def generate(
     nb,
     height,
     width,
-    max_nb_obj=2,
+    nb_birds=2,
     nb_iterations=2,
 ):
     pairs = []
@@ -49,7 +69,6 @@ def generate(
         f_end = torch.zeros(height, width, dtype=torch.int64)
         n = torch.arange(f_start.size(0))
 
-        nb_birds = torch.randint(max_nb_obj, (1,)).item() + 1
         for c in (
             (torch.randperm(nb_bird_tokens) + first_bird_token)[:nb_birds].sort().values
         ):
@@ -115,6 +134,10 @@ def sample2img(seq, height, width, upscale=15):
         x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale)
         x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale)
 
+        x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0
+        x[:, :, torch.arange(0, x.size(2), upscale), :] = 0
+        x = x[:, :, 1:, 1:]
+
         for n in range(m.size(0)):
             for i in range(m.size(1)):
                 for j in range(m.size(2)):
@@ -125,9 +148,9 @@ def sample2img(seq, height, width, upscale=15):
 
         return x
 
-    direction_symbol = torch.full((direction.size(0), height * upscale, upscale), 0)
+    direction_symbol = torch.full((direction.size(0), height * upscale - 1, upscale), 0)
     direction_symbol = colors[direction_symbol].permute(0, 3, 1, 2)
-    separator = torch.full((direction.size(0), 3, height * upscale, 1), 0)
+    separator = torch.full((direction.size(0), 3, height * upscale - 1, 1), 0)
 
     for n in range(direction_symbol.size(0)):
         if direction[n] == token_forward:
@@ -181,7 +204,7 @@ if __name__ == "__main__":
 
     height, width = 6, 8
     start_time = time.perf_counter()
-    seq = generate(nb=90, height=height, width=width, max_nb_obj=3)
+    seq = generate(nb=90, height=height, width=width)
     delay = time.perf_counter() - start_time
     print(f"{seq.size(0)/delay:02f} samples/s")
 
@@ -194,5 +217,5 @@ if __name__ == "__main__":
     print(img.size())
 
     torchvision.utils.save_image(
-        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=4
+        img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0
     )