Update.
[culture.git] / main.py
diff --git a/main.py b/main.py
index b57c512..b4ab473 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,32 +12,27 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
 from torch.nn import functional as F
 
 import ffutils
-import mygpt, tasks, problems
 
 
-######################################################################
+import mygpt
+import sky, grids, quiz_machine
 
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
+import threading
+
+import torch.multiprocessing as mp
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache.",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument("--task", type=str, default="world", help="world")
-
-parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+parser.add_argument("--log_filename", type=str, default="train.log")
 
 parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
 
 parser.add_argument("--result_dir", type=str, default=None)
 
 parser.add_argument("--seed", type=int, default=0)
 
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
 ########################################
 
 
 ########################################
 
@@ -51,7 +46,7 @@ parser.add_argument("--nb_train_samples", type=int, default=None)
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
 
 parser.add_argument("--nb_test_samples", type=int, default=None)
 
-parser.add_argument("--learning_rate", type=float, default=1e-4)
+parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 ########################################
 
 
 ########################################
 
@@ -73,30 +68,78 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-parser.add_argument("--check", action="store_true", default=False)
+parser.add_argument("--problem", type=str, default="grids")
+
+parser.add_argument("--nb_threads", type=int, default=1)
+
+parser.add_argument("--gpus", type=str, default="all")
+
+parser.add_argument("--nb_gpts", type=int, default=5)
+
+parser.add_argument("--min_to_validate", type=int, default=None)
+
+parser.add_argument("--max_to_validate", type=int, default=None)
+
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+
+parser.add_argument("--proba_understands", type=float, default=0.99)
+
+parser.add_argument("--proba_not_understands", type=float, default=0.5)
+
+parser.add_argument("--generation_temperature", type=float, default=2.0)
+
+parser.add_argument("--dirty_debug", action="store_true", default=False)
+
+######################################################################
+
+grids_tasks = ", ".join(
+    [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
+
+parser.add_argument(
+    "--grids_tasks",
+    type=str,
+    default=None,
+    help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
+
+######################################################################
+
+parser.add_argument("--sky_height", type=int, default=6)
+
+parser.add_argument("--sky_width", type=int, default=8)
+
+parser.add_argument("--sky_nb_birds", type=int, default=3)
+
+parser.add_argument("--sky_nb_iterations", type=int, default=2)
+
+parser.add_argument("--sky_speed", type=int, default=3)
 
 ######################################################################
 
 args = parser.parse_args()
 
 
 ######################################################################
 
 args = parser.parse_args()
 
+if args.min_to_validate is None:
+    args.min_to_validate = args.nb_gpts - 1
+
+if args.max_to_validate is None:
+    args.max_to_validate = args.nb_gpts - 1
+
 if args.result_dir is None:
 if args.result_dir is None:
-    args.result_dir = f"results_{args.task}"
+    args.result_dir = f"results_culture"
 
 ######################################################################
 
 
 ######################################################################
 
-default_task_args = {
-    "world": {
-        "model": "37M",
-        "batch_size": 100,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
+default_args = {
+    "model": "37M",
+    "batch_size": 25,
+    "nb_train_samples": 100000,
+    "nb_test_samples": 10000,
 }
 
 }
 
-if args.task in default_task_args:
-    for k, v in default_task_args[args.task].items():
-        if getattr(args, k) is None:
-            setattr(args, k, v)
+for k, v in default_args.items():
+    if getattr(args, k) is None:
+        setattr(args, k, v)
 
 ######################################################################
 
 
 ######################################################################
 
@@ -185,9 +228,22 @@ for n in vars(args):
 
 ######################################################################
 
 
 ######################################################################
 
-if args.test:
-    args.nb_train_samples = 1000
-    args.nb_test_samples = 25
+if args.gpus == "all":
+    gpus_idx = range(torch.cuda.device_count())
+else:
+    gpus_idx = [int(k) for k in args.gpus.split(",")]
+
+gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+
+if torch.cuda.is_available():
+    main_device = gpus[0]
+else:
+    assert len(gpus) == 0
+    main_device = torch.device("cpu")
+
+if args.dirty_debug:
+    args.nb_train_samples = 2500
+    args.nb_test_samples = 100
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
 
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
@@ -197,296 +253,93 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
-if args.task == "file":
-    assert (
-        args.filetask_train_file is not None and args.filetask_test_file is not None
-    ), "You have to specify the task train and test files"
-    task = tasks.TaskFromFile(
-        args.filetask_train_file,
-        args.filetask_test_file,
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        shuffle=True,
-        device=device,
-    )
-    args.max_percents_of_test_in_train = 0
-
-elif args.task == "byheart":
-    task = tasks.SandBox(
-        problem=problems.ProblemByHeart(separation=args.byheart_separation),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-    args.max_percents_of_test_in_train = -1
-
-elif args.task == "world":
-    task = tasks.World(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        result_dir=args.result_dir,
-        logger=log_string,
-        device=device,
-    )
-    args.max_percents_of_test_in_train = -1
-
-elif args.task == "learnop":
-    task = tasks.SandBox(
-        problem=problems.ProblemLearnOperator(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-
-elif args.task == "guessop":
-    task = tasks.SandBox(
-        problem=problems.ProblemGuessOperator(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-
-elif args.task == "twotargets":
-    task = tasks.SandBox(
-        problem=problems.ProblemTwoTargets(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "memory":
-    task = tasks.SandBox(
-        problem=problems.ProblemMemory(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "mixing":
-    task = tasks.SandBox(
-        problem=problems.ProblemMixing(
-            hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
-        ),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "addition":
-    task = tasks.SandBox(
-        problem=problems.ProblemAddition(),
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "picoclvr":
-    task = tasks.PicoCLVR(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.picoclvr_height,
-        width=args.picoclvr_width,
-        nb_colors=args.picoclvr_nb_colors,
-        logger=log_string,
-        device=device,
-        pruner_train=picoclvr_pruner_train,
-        pruner_eval=picoclvr_pruner_eval,
-    )
-
-elif args.task == "mnist":
-    task = tasks.MNIST(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        device=device,
-    )
-
-elif args.task == "maze":
-    task = tasks.Maze(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.maze_height,
-        width=args.maze_width,
-        nb_walls=args.maze_nb_walls,
-        device="cpu",
-    )
-
-elif args.task == "snake":
-    task = tasks.Snake(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.snake_height,
-        width=args.snake_width,
-        nb_colors=args.snake_nb_colors,
-        length=args.snake_length,
-        prompt_length=args.snake_length // 2,
-        device=device,
-    )
-
-elif args.task == "stack":
-    task = tasks.Stack(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        logger=log_string,
-        nb_steps=args.stack_nb_steps,
-        nb_stacks=args.stack_nb_stacks,
-        nb_digits=args.stack_nb_digits,
-        fraction_values_for_train=args.stack_fraction_values_for_train,
-        device=device,
-    )
-
-elif args.task == "expr":
-    task = tasks.Expr(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        nb_variables=args.expr_nb_variables,
-        sequence_length=args.expr_sequence_length,
-        operand_max=args.expr_operand_max,
-        result_max=args.expr_result_max,
-        batch_size=args.physical_batch_size,
-        device=device,
-    )
-
-elif args.task == "rpl":
-    task = tasks.RPL(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        nb_starting_values=args.rpl_nb_starting_values,
-        max_input=args.rpl_max_input,
-        prog_len=args.rpl_prog_len,
-        nb_runs=args.rpl_nb_runs,
-        no_prog=args.rpl_no_prog,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "grid":
-    task = tasks.Grid(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        size=args.grid_size,
-        fraction_play=args.grid_fraction_play,
-        logger=log_string,
-        device=device,
-    )
-
-elif args.task == "qmlp":
-    task = tasks.QMLP(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        result_dir=args.result_dir,
-        logger=log_string,
-        device=device,
+if args.problem == "sky":
+    problem = sky.Sky(
+        height=args.sky_height,
+        width=args.sky_width,
+        nb_birds=args.sky_nb_birds,
+        nb_iterations=args.sky_nb_iterations,
+        speed=args.sky_speed,
+        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+        chunk_size=100,
+        nb_threads=args.nb_threads,
     )
     )
-
-elif args.task == "greed":
-    task = tasks.Greed(
-        nb_train_samples=args.nb_train_samples,
-        nb_test_samples=args.nb_test_samples,
-        batch_size=args.physical_batch_size,
-        height=args.greed_height,
-        width=args.greed_width,
-        T=args.greed_T,
-        nb_walls=args.greed_nb_walls,
-        nb_coins=args.greed_nb_coins,
-        logger=log_string,
-        device=device,
+    back_accuracy = False
+elif args.problem == "grids":
+    problem = grids.Grids(
+        max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+        chunk_size=100,
+        nb_threads=args.nb_threads,
+        tasks=args.grids_tasks,
     )
     )
-
+    back_accuracy = True
 else:
 else:
-    raise ValueError(f"Unknown task {args.task}")
+    raise ValueError
+
+problem.save_some_examples(args.result_dir)
+
+quiz_machine = quiz_machine.QuizMachine(
+    problem=problem,
+    nb_train_samples=args.nb_train_samples,
+    nb_test_samples=args.nb_test_samples,
+    back_accuracy=back_accuracy,
+    batch_size=args.physical_batch_size,
+    result_dir=args.result_dir,
+    logger=log_string,
+    device=main_device,
+)
 
 ######################################################################
 
 
 ######################################################################
 
-log_string(f"device {device}")
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
 
 
-vocabulary_size = task.vocabulary_size()
+vocabulary_size = quiz_machine.vocabulary_size()
 
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
 
 log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
-# Compute the entropy of the training tokens
 
 
-token_count = 0
-for input in task.batches(split="train", desc="train-entropy"):
-    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
+def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
+    with torch.autograd.no_grad():
+        model.eval().to(local_device)
 
 
-######################################################################
-# A bit of paranoia never hurts
+        nb_test_samples, acc_test_loss = 0, 0.0
+        nb_samples_accumulated = 0
 
 
-if args.max_percents_of_test_in_train >= 0:
+        for input in quiz_machine.batches(model, split="test"):
+            input = input.to(local_device)
 
 
-    def subsets_as_tuples(batches, cs):
-        s = set()
-        for batch in batches:
-            for x in batch:
-                s.add(tuple([v.item() for v in x]))
-                if len(s) == cs:
-                    yield s
-                    s = set()
-        yield s
+            bs = model(mygpt.BracketedSequence(input))
+            output = bs.x
 
 
-    nb_test, nb_in_train = 0, 0
-    for test_subset in subsets_as_tuples(
-        task.batches(split="test", desc="test-check"), 25000
-    ):
-        in_train = set()
-        for train_subset in subsets_as_tuples(
-            task.batches(split="train", desc="train-check"), 25000
-        ):
-            in_train.update(test_subset.intersection(train_subset))
-        nb_in_train += len(in_train)
-        nb_test += len(test_subset)
+            loss = F.cross_entropy(output.transpose(1, 2), input)
 
 
-    log_string(
-        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
-    )
+            acc_test_loss += loss.item() * input.size(0)
 
 
-    assert (
-        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
-    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+            nb_test_samples += input.size(0)
+
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+        log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
 
-##############################
+        model.main_test_accuracy = quiz_machine.produce_results(
+            n_epoch=n_epoch,
+            model=model,
+            result_dir=args.result_dir,
+            deterministic_synthesis=deterministic_synthesis,
+        )
 
 
 
 
-def one_epoch(model, task):
+def one_epoch(model, quiz_machine, local_device=main_device):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    model.train()
+    model.to(local_device).train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in task.batches(split="train"):
-        input = input.to(device)
+    for input in quiz_machine.batches(model, split="train"):
+        input = input.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
             optimizer.zero_grad()
 
         if nb_train_samples % args.batch_size == 0:
             optimizer.zero_grad()
@@ -504,90 +357,101 @@ def one_epoch(model, task):
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
 
-    log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+    log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
+
+    run_tests(model, quiz_machine, deterministic_synthesis=False)
+
+    model.to(main_device)
 
 
 ######################################################################
 
 
 
 
 ######################################################################
 
 
-def run_tests(model, task, deterministic_synthesis):
-    with torch.autograd.no_grad():
-        model.eval()
+def standard_validity(logproba):
+    l = logproba.sort(dim=-1).values
+    return (l[:, 0] < math.log(args.proba_not_understands)) & (
+        l[:, 1] > math.log(args.proba_understands)
+    )
 
 
-        nb_test_samples, acc_test_loss = 0, 0.0
-        nb_samples_accumulated = 0
 
 
-        for input in task.batches(split="test"):
-            input = input.to(device)
+def valid_c_quizzes(recorded, criteria):
+    result = [q[criteria(lp)] for q, lp in recorded]
+    return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
 
 
-            bs = model(mygpt.BracketedSequence(input))
-            output = bs.x
 
 
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+######################################################################
 
 
-            acc_test_loss += loss.item() * input.size(0)
 
 
-            nb_test_samples += input.size(0)
+def create_c_quizzes(
+    models,
+    quiz_machine,
+    nb_for_train=1000,
+    nb_for_test=100,
+):
+    quizzes_and_logproba_records = []
 
 
-        main_test_accuracy = task.produce_results(
-            n_epoch=n_epoch,
-            model=model,
-            result_dir=args.result_dir,
-            logger=log_string,
-            deterministic_synthesis=deterministic_synthesis,
-        )
+    nb_to_create = nb_for_train + nb_for_test
 
 
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+    # ------------------------------------------------------------
 
 
-        log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+    file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
 
 
-    model.main_test_accuracy = main_test_accuracy
+    with open(file_name, "w") as logp_file:
+        while (
+            valid_c_quizzes(quizzes_and_logproba_records, standard_validity).size(0)
+            < nb_to_create
+        ):
+            # Select a model at random to generate the new quizzes
 
 
+            model_for_generation = models[torch.randint(len(models), (1,))]
 
 
-######################################################################
+            c_quizzes = quiz_machine.generate_quizzes(
+                nb_to_create,
+                model_for_generation=model_for_generation,
+                temperature=args.generation_temperature,
+            )
 
 
+            c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
 
 
-def create_quizzes(
-    model,
-    other_models,
-    task,
-    nb_for_train=1000,
-    nb_for_test=100,
-):
-    kept = []
+            if c_quizzes.size(0) > 0:
+                logproba = quiz_machine.logproba_of_solutions(models, c_quizzes)
+                for l in logproba:
+                    s = " ".join([str(x.item()) for x in l])
+                    logp_file.write(s + "\n")
+                quizzes_and_logproba_records.append((c_quizzes, logproba))
 
 
-    while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
-        new_quizzes, nb_correct = task.create_new_quizzes(
-            n_epoch=n_epoch,
-            result_dir=args.result_dir,
-            logger=log_string,
-            nb=4 * (nb_for_train + nb_for_test),
-            model=model,
-            other_models=other_models,
-        )
+            nb_validated = valid_c_quizzes(
+                quizzes_and_logproba_records, standard_validity
+            ).size(0)
 
 
-        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
-        log_string(f"keep {to_keep.size(0)} quizzes")
-        kept.append(to_keep)
+            log_string(
+                f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+            )
 
 
-    new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+    # store the new c_quizzes which have been validated
 
 
-    task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
-    task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
+    new_c_quizzes = valid_c_quizzes(quizzes_and_logproba_records, standard_validity)
 
 
-    task.save_image(
-        new_quizzes[:96],
-        args.result_dir,
-        f"world_new_{n_epoch:04d}_{model.id:02d}.png",
-        log_string,
-    )
+    quiz_machine.reverse_random_half_in_place(new_c_quizzes)
+
+    quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+    quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+
+    # save a bunch of images to investigate what quizzes with a
+    # certain nb of correct predictions look like
+
+    q = new_c_quizzes[:72]
+
+    if q.size(0) > 0:
+        quiz_machine.save_quizzes(args.result_dir, f"culture_c_quiz_{n_epoch:04d}", q)
 
 
 ######################################################################
 
 models = []
 
 
 
 ######################################################################
 
 models = []
 
-for k in range(5):
+for k in range(args.nb_gpts):
+    log_string(f"creating model {k} and its w_quizzes")
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
@@ -597,11 +461,16 @@ for k in range(5):
         nb_blocks=args.nb_blocks,
         causal=True,
         dropout=args.dropout,
         nb_blocks=args.nb_blocks,
         causal=True,
         dropout=args.dropout,
-    ).to(device)
+    ).to(main_device)
 
     model.main_test_accuracy = 0.0
     model.id = k
 
 
     model.main_test_accuracy = 0.0
     model.id = k
 
+    model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
+    quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
+    model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
+    quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+
     models.append(model)
 
 
     models.append(model)
 
 
@@ -610,45 +479,127 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
 
 ######################################################################
 
-accuracy_to_make_quizzes = 0.975
-nb_new_quizzes_for_train = 1000
-nb_new_quizzes_for_test = 100
+# Compute the entropy of the training tokens
 
 
-if args.test:
-    accuracy_to_make_quizzes = 0.0
-    nb_new_quizzes_for_train = 10
-    nb_new_quizzes_for_test = 10
+token_count = 0
+for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
+    token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
+        (0, 1)
+    )
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
 
 
-for n_epoch in range(args.nb_epochs):
-    # select the model with lowest accuracy
-    models.sort(key=lambda model: model.main_test_accuracy)
-    model = models[0]
+######################################################################
+# A bit of paranoia never hurts
+
+if args.max_percents_of_test_in_train >= 0:
+
+    def subsets_as_tuples(batches, cs):
+        s = set()
+        for batch in batches:
+            for x in batch:
+                s.add(tuple([v.item() for v in x]))
+                if len(s) == cs:
+                    yield s
+                    s = set()
+        yield s
+
+    nb_test, nb_in_train = 0, 0
+    for test_subset in subsets_as_tuples(
+        quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
+    ):
+        in_train = set()
+        for train_subset in subsets_as_tuples(
+            quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
+        ):
+            in_train.update(test_subset.intersection(train_subset))
+        nb_in_train += len(in_train)
+        nb_test += len(test_subset)
 
     log_string(
 
     log_string(
-        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+        f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
     )
 
     )
 
-    # improve it
-    one_epoch(model, task)
+    assert (
+        nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+    ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+######################################################################
+
+nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+
+log_string(
+    f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+)
+
+######################################################################
+
+if args.dirty_debug:
+    args.accuracy_to_make_c_quizzes = 0.0
+    args.nb_gpts = 2
+    nb_new_c_quizzes_for_train = 100
+    nb_new_c_quizzes_for_test = 10
+
+    def standard_validity(logproba):
+        l = logproba.sort(dim=-1).values
+        return l[:, 0] < math.log(0.5)
+
+
+######################################################################
+
+for n_epoch in range(args.nb_epochs):
+    log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+    log_string(f"current_test_accuracies {cta}")
+
+    ##################################################
+    # Select, improve, and eval the worst model
+
+    ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+
+    weakest_models = ranked_models[: len(gpus)]
+
+    threads = []
+
+    for gpu, model in zip(gpus, weakest_models):
+        log_string(f"training model {model.id}")
+
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+        )
+
+        threads.append(t)
+
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    ##################################################
+    # Replace a fraction of the w_quizzes with fresh ones
 
     log_string(
 
     log_string(
-        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+        f"cache_w_quizzes contains {quiz_machine.problem.nb_cached_quizzes()} quizzes"
     )
 
     )
 
-    # test it
-    run_tests(model, task, deterministic_synthesis=False)
+    # Renew entirely the train set
 
 
-    if model.main_test_accuracy >= accuracy_to_make_quizzes:
-        other_models = models.copy()
-        other_models.remove(model)
+    for model in weakest_models:
+        quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
 
 
-        create_quizzes(
-            model,
-            other_models,
-            task,
-            nb_for_train=nb_new_quizzes_for_train,
-            nb_for_test=nb_new_quizzes_for_test,
-        )
+    ##################################################
+    # If all the models are good enough, generate new quizzes and
+    # re-compute the test errors
 
 
+    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+        create_c_quizzes(
+            models,
+            quiz_machine,
+            nb_for_train=nb_new_c_quizzes_for_train,
+            nb_for_test=nb_new_c_quizzes_for_test,
+        )
 
 ######################################################################
 
 ######################################################################