# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, argparse, time, tqdm, os, datetime, warnings
+import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
import torch, torchvision
from torch import nn
from torch.nn import functional as F
-import ffutils
-import mygpt, tasks, problems
+import ffutils, grids, attae
-######################################################################
+import threading, subprocess
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
+# import torch.multiprocessing as mp
+
+torch.set_float32_matmul_precision("high")
+
+# torch.set_default_dtype(torch.bfloat16)
######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-parser.add_argument("--task", type=str, default="world", help="world")
-
-parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+parser.add_argument("--log_filename", type=str, default="train.log")
parser.add_argument("--result_dir", type=str, default=None)
parser.add_argument("--seed", type=int, default=0)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--resume", action="store_true", default=False)
-########################################
+# ----------------------------------
parser.add_argument("--nb_epochs", type=int, default=10000)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
+
+parser.add_argument("--train_batch_size", type=int, default=None)
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--eval_batch_size", type=int, default=25)
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
-parser.add_argument("--learning_rate", type=float, default=1e-4)
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
-########################################
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--learning_rate", type=float, default=5e-4)
+
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+# ----------------------------------
+
+parser.add_argument("--model_type", type=str, default="standard")
+
+parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--nb_blocks", type=int, default=None)
-parser.add_argument("--dropout", type=float, default=0.1)
+parser.add_argument("--dropout", type=float, default=0.5)
+
+# ----------------------------------
-########################################
+parser.add_argument("--nb_threads", type=int, default=1)
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+parser.add_argument("--gpus", type=str, default="all")
+
+# ----------------------------------
+
+parser.add_argument("--nb_models", type=int, default=5)
+
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
+
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
+
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
+
+parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
+
+parser.add_argument("--proba_hint", type=float, default=0.25)
+
+parser.add_argument("--quizzes", type=str, default=None)
######################################################################
-args = parser.parse_args()
+grids_tasks = ", ".join(
+ [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
-if args.result_dir is None:
- args.result_dir = f"results_{args.task}"
+parser.add_argument(
+ "--grids_world_tasks",
+ type=str,
+ default="replace_color,translate,grow,frame",
+ help="A comma-separated subset of: " + grids_tasks + ".",
+)
######################################################################
-default_task_args = {
- "world": {
- "model": "37M",
- "batch_size": 100,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
-}
+args = parser.parse_args()
-if args.task in default_task_args:
- for k, v in default_task_args[args.task].items():
- if getattr(args, k) is None:
- setattr(args, k, v)
+if args.result_dir is None:
+ args.result_dir = f"results_culture"
######################################################################
######################################################################
-try:
- os.mkdir(args.result_dir)
-except FileExistsError:
- print(f"result directory {args.result_dir} already exists")
- exit(1)
+if args.resume:
+ if not os.path.isdir(args.result_dir):
+ print(f"Trying to resume from a non-existing result dir {args.result_dir}.")
+ exit(1)
+else:
+ try:
+ os.mkdir(args.result_dir)
+ except FileExistsError:
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
def log_string(s):
+ """print the given string prefixed with a time stamps, and log it
+ into log_file is not None"""
+
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
if log_file is not None:
sys.stdout.flush()
+######################################################################
+# Create a time-stamped archive of the source code
+
+with open("this_run.sh", "w") as f:
+ f.write(f"{' '.join(sys.argv)}\n")
+
+now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
+
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+######################################################################
+
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
######################################################################
+if args.gpus == "all":
+ gpus_idx = range(torch.cuda.device_count())
+else:
+ gpus_idx = [int(k) for k in args.gpus.split(",")]
-if args.physical_batch_size is None:
- args.physical_batch_size = args.batch_size
+gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+
+if torch.cuda.is_available():
+ main_device = gpus[0]
+else:
+ assert len(gpus) == 0
+ main_device = torch.device("cpu")
+
+if args.train_batch_size is None:
+ args.train_batch_size = args.batch_size
else:
- assert args.batch_size % args.physical_batch_size == 0
+ assert args.batch_size % args.train_batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-if args.task == "file":
- assert (
- args.filetask_train_file is not None and args.filetask_test_file is not None
- ), "You have to specify the task train and test files"
- task = tasks.TaskFromFile(
- args.filetask_train_file,
- args.filetask_test_file,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- shuffle=True,
- device=device,
- )
- args.max_percents_of_test_in_train = 0
-
-elif args.task == "byheart":
- task = tasks.SandBox(
- problem=problems.ProblemByHeart(separation=args.byheart_separation),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
- args.max_percents_of_test_in_train = -1
-
-elif args.task == "world":
- task = tasks.World(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=device,
- )
- args.max_percents_of_test_in_train = -1
-
-elif args.task == "learnop":
- task = tasks.SandBox(
- problem=problems.ProblemLearnOperator(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
+######################################################################
-elif args.task == "guessop":
- task = tasks.SandBox(
- problem=problems.ProblemGuessOperator(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
+def optimizer_to(optim, device):
+ """Move the optimizer optim to the device"""
+ for param in optim.state.values():
+ # Not sure there are any global tensors in the state dict
+ if isinstance(param, torch.Tensor):
+ param.data = param.data.to(device)
+ if param._grad is not None:
+ param._grad.data = param._grad.data.to(device)
+ elif isinstance(param, dict):
+ for subparam in param.values():
+ if isinstance(subparam, torch.Tensor):
+ subparam.data = subparam.data.to(device)
+ if subparam._grad is not None:
+ subparam._grad.data = subparam._grad.data.to(device)
-elif args.task == "twotargets":
- task = tasks.SandBox(
- problem=problems.ProblemTwoTargets(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
+######################################################################
-elif args.task == "memory":
- task = tasks.SandBox(
- problem=problems.ProblemMemory(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
-elif args.task == "mixing":
- task = tasks.SandBox(
- problem=problems.ProblemMixing(
- hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
- ),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
+def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+ if c_quizzes is None:
+ quizzes = problem.generate_w_quizzes(nb_samples)
+ nb_w_quizzes = quizzes.size(0)
+ nb_c_quizzes = 0
+ else:
+ if c_quiz_multiplier > 1:
+ n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+ body = c_quizzes.repeat(n, 1)
+ if n < c_quiz_multiplier:
+ tail = c_quizzes[
+ torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+ ]
+ c_quizzes = torch.cat([body, tail], dim=0)
+ else:
+ c_quizzes = body
-elif args.task == "addition":
- task = tasks.SandBox(
- problem=problems.ProblemAddition(),
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- device=device,
- )
+ if c_quizzes.size(0) > nb_samples // 2:
+ i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+ c_quizzes = c_quizzes[i]
-elif args.task == "picoclvr":
- task = tasks.PicoCLVR(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- height=args.picoclvr_height,
- width=args.picoclvr_width,
- nb_colors=args.picoclvr_nb_colors,
- logger=log_string,
- device=device,
- pruner_train=picoclvr_pruner_train,
- pruner_eval=picoclvr_pruner_eval,
- )
+ w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
-elif args.task == "mnist":
- task = tasks.MNIST(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- device=device,
- )
+ quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+ nb_w_quizzes = w_quizzes.size(0)
+ nb_c_quizzes = c_quizzes.size(0)
-elif args.task == "maze":
- task = tasks.Maze(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- height=args.maze_height,
- width=args.maze_width,
- nb_walls=args.maze_nb_walls,
- device="cpu",
- )
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
+ quizzes = quizzes[i].contiguous()
+
+ log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+ return quizzes
+
+
+######################################################################
+
+
+def add_hints_imt(imt_set):
+ """Set every component of the mask to zero with probability
+ args.proba_hint, and for each component set to zero, copy the
+ corresponding value from the target into the input
+
+ """
+ input, masks, targets = imt_set.unbind(dim=1)
+ # h = torch.rand(masks.size(), device=masks.device) - masks
+ # t = h.sort(dim=1).values[:, args.nb_hints, None]
+ # mask_hints = (h < t).long()
+ mask_hints = (
+ torch.rand(input.size(), device=input.device) < args.proba_hint
+ ).long() * masks
+ masks = (1 - mask_hints) * masks
+ input = (1 - mask_hints) * input + mask_hints * targets
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def add_noise_imt(imt_set):
+ """Replace every component of the input by a random value with
+ probability args.proba_prompt_noise."""
+ input, masks, targets = imt_set.unbind(dim=1)
+ noise = problem.pure_noise(input.size(0), input.device)
+ change = (1 - masks) * (
+ torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+ ).long()
+ input = (1 - change) * input + change * noise
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+######################################################################
+# Prediction
+
+
+def samples_for_prediction_imt(input):
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
+
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def ae_predict(model, imt_set, local_device=main_device):
+ model.eval().to(local_device)
+
+ record = []
-elif args.task == "snake":
- task = tasks.Snake(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- height=args.snake_height,
- width=args.snake_width,
- nb_colors=args.snake_nb_colors,
- length=args.snake_length,
- prompt_length=args.snake_length // 2,
- device=device,
+ src = tqdm.tqdm(
+ imt_set.split(args.eval_batch_size),
+ dynamic_ncols=True,
+ desc="predict",
+ total=imt_set.size(0) // args.eval_batch_size,
+ delay=10,
)
-elif args.task == "stack":
- task = tasks.Stack(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- logger=log_string,
- nb_steps=args.stack_nb_steps,
- nb_stacks=args.stack_nb_stacks,
- nb_digits=args.stack_nb_digits,
- fraction_values_for_train=args.stack_fraction_values_for_train,
- device=device,
+ for imt in src:
+ # some paranoia
+ imt = imt.clone()
+ imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(imt[:, 0] * 2 + imt[:, 1])
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
+ record.append(result)
+
+ return torch.cat(record)
+
+
+def predict_the_four_grids(
+ model, input, with_noise=False, with_hints=False, local_device=main_device
+):
+ input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
+ imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+ if with_hints:
+ imt_set = add_hints_imt(imt_set)
+
+ if with_noise:
+ imt_set = add_noise_imt(imt_set)
+
+ result = ae_predict(model, imt_set, local_device=local_device)
+ result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+ return result
+
+
+######################################################################
+
+
+def samples_for_generation_imt(input):
+ nb = input.size(0)
+ probs_iterations = 0.1 ** torch.linspace(
+ 0, 1, args.diffusion_nb_iterations, device=input.device
)
+ probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+ probs_iterations = probs_iterations.expand(nb, -1)
+ dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+ t = dist.sample() + 1
+ r = torch.rand(input.size(), device=input.device)
+ proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+ mask_erased = (r <= proba_erased[:, None]).long()
+
+ noise = problem.pure_noise(nb, input.device)
+ targets = input
+ input = (1 - mask_erased) * input + mask_erased * noise
+ masks = input.new_full(input.size(), 1)
+
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def prioritized_rand(low):
+ x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+ k = torch.rand(low.size(), device=low.device) + low.long()
+ k = k.sort(dim=1).indices
+ y = x.new(x.size())
+ y.scatter_(dim=1, index=k, src=x)
+ return y
+
+
+def ae_generate(model, nb, local_device=main_device):
+ model.eval().to(local_device)
+
+ # We loop through the iterations first and through the
+ # mini-batches second so that we keep only the samples that have
+ # not stabilized
+
+ all_input = problem.pure_noise(nb, local_device)
+ all_masks = all_input.new_full(all_input.size(), 1)
+ all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
+
+ for it in range(args.diffusion_nb_iterations):
+ # log_string(f"nb_changed {all_changed.long().sum().item()}")
+
+ if not all_changed.any():
+ break
+
+ sub_input = all_input[all_changed].clone()
+ sub_masks = all_masks[all_changed].clone()
+ sub_changed = all_changed[all_changed].clone()
+
+ src = zip(
+ sub_input.split(args.eval_batch_size),
+ sub_masks.split(args.eval_batch_size),
+ sub_changed.split(args.eval_batch_size),
+ )
+
+ for input, masks, changed in src:
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ output = dist.sample()
+ r = prioritized_rand(input != output)
+ mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+ update = (1 - mask_changes) * input + mask_changes * output
+ changed[...] = changed & (update != input).max(dim=1).values
+ input[...] = update
+
+ a = all_changed.clone()
+ all_input[a] = sub_input
+ all_masks[a] = sub_masks
+ all_changed[a] = sub_changed
+
+ return all_input
+
+
+######################################################################
+
-elif args.task == "expr":
- task = tasks.Expr(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- nb_variables=args.expr_nb_variables,
- sequence_length=args.expr_sequence_length,
- operand_max=args.expr_operand_max,
- result_max=args.expr_result_max,
- batch_size=args.physical_batch_size,
- device=device,
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
+ quizzes = generate_quiz_set(
+ args.nb_train_samples if train else args.nb_test_samples,
+ c_quizzes,
+ args.c_quiz_multiplier,
)
-elif args.task == "rpl":
- task = tasks.RPL(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- nb_starting_values=args.rpl_nb_starting_values,
- max_input=args.rpl_max_input,
- prog_len=args.rpl_prog_len,
- nb_runs=args.rpl_nb_runs,
- no_prog=args.rpl_no_prog,
- logger=log_string,
- device=device,
+ q_p, q_g = quizzes.to(local_device).chunk(2)
+
+ # Half of the samples train the prediction, and we inject noise in
+ # all, and hints in half
+ b_p = samples_for_prediction_imt(q_p)
+ b_p = add_noise_imt(b_p)
+ half = torch.rand(b_p.size(0)) < 0.5
+ b_p[half] = add_hints_imt(b_p[half])
+
+ # The other half are denoising examples for the generation
+ b_g = samples_for_generation_imt(q_g)
+
+ imt_set = torch.cat([b_p, b_g])
+ imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+
+ if train:
+ label = "train"
+ model.train().to(local_device)
+ optimizer_to(model.optimizer, local_device)
+ batch_size = args.train_batch_size
+ else:
+ label = "test"
+ model.eval().to(local_device)
+ batch_size = args.eval_batch_size
+
+ nb_samples, acc_loss = 0, 0.0
+
+ for imt in tqdm.tqdm(
+ imt_set.split(batch_size),
+ dynamic_ncols=True,
+ desc=label,
+ total=quizzes.size(0) // batch_size,
+ delay=10,
+ ):
+ input, masks, targets = imt.unbind(dim=1)
+ if train and nb_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+
+ loss_per_token = F.cross_entropy(
+ logits.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * masks).mean()
+ acc_loss += loss.item() * imt.size(0)
+ nb_samples += imt.size(0)
+
+ if train:
+ loss.backward()
+
+ if nb_samples % args.batch_size == 0:
+ model.optimizer.step()
+
+ log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
+
+
+######################################################################
+
+
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
+ # Save some images of the prediction results
+
+ quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ masks = imt_set[:, 1].to("cpu")
+
+ correct = (quizzes == result).min(dim=1).values.long()
+ correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
+ :, :, 1
+ ]
+ predicted_parts = correct_parts.abs()
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_prediction_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ predicted_parts=predicted_parts[:128],
+ correct_parts=correct_parts[:128],
)
-elif args.task == "grid":
- task = tasks.Grid(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- size=args.grid_size,
- fraction_play=args.grid_fraction_play,
- logger=log_string,
- device=device,
+ # Save some images of the ex nihilo generation of the four grids
+
+ result = ae_generate(model, 150, local_device=local_device).to("cpu")
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_generation_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
)
-elif args.task == "qmlp":
- task = tasks.QMLP(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=device,
+
+######################################################################
+
+
+def one_complete_epoch(
+ model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
+):
+ one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
+
+ one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
+
+ # Compute the test accuracy
+
+ quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ correct = (quizzes == result).min(dim=1).values.long()
+
+ nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+ model.test_accuracy = nb_correct / nb_total
+
+ log_string(
+ f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
)
-elif args.task == "greed":
- task = tasks.Greed(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.physical_batch_size,
- height=args.greed_height,
- width=args.greed_width,
- T=args.greed_T,
- nb_walls=args.greed_nb_walls,
- nb_coins=args.greed_nb_coins,
- logger=log_string,
- device=device,
+ save_inference_images(
+ model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device
)
-else:
- raise ValueError(f"Unknown task {args.task}")
######################################################################
-log_string(f"device {device}")
-vocabulary_size = task.vocabulary_size()
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+ return (
+ (prediction != quizzes)
+ .long()
+ .reshape(quizzes.size(0), 4, -1)
+ .sum(dim=2)
+ .max(dim=1)
+ .values
+ )
-log_string(f"vocabulary_size {vocabulary_size}")
-######################################################################
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
+ nb_correct, nb_wrong = 0, 0
+
+ for model in models:
+ model = copy.deepcopy(model).to(local_device).eval()
+ predicted = predict_the_four_grids(
+ model=model,
+ input=quizzes,
+ with_noise=False,
+ with_hints=with_hints,
+ local_device=local_device,
+ )
+ nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
+ nb_correct += (nb_mistakes == 0).long()
+ nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+ # print("\n\n", nb_correct, nb_wrong)
-# Compute the entropy of the training tokens
+ return nb_correct, nb_wrong
-token_count = 0
-for input in task.batches(split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
-
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- task.batches(split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- task.batches(split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+def identity_quizzes(quizzes):
+ quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
+ return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
+ quizzes[:, 2] == quizzes[:, 3]
+ ).min(dim=1).values
-##############################
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
+ record = []
+ nb_validated = 0
-def one_epoch(model, task):
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ start_time = time.perf_counter()
+ last_log = -1
- model.train()
+ while nb_validated < nb_to_generate:
+ # Generate new quizzes
- nb_train_samples, acc_train_loss = 0, 0.0
+ model = models[torch.randint(len(models), (1,)).item()]
+ model = copy.deepcopy(model).to(local_device).eval()
+ generator_id = model.id
- for input in task.batches(split="train"):
- input = input.to(device)
+ c_quizzes = ae_generate(
+ model=model, nb=args.eval_batch_size * 10, local_device=local_device
+ )
- if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
+ c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
+ if c_quizzes.size(0) > 0:
+ # Select the ones that are solved properly by some models and
+ # not understood by others
- nb_train_samples += input.size(0)
+ nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ with_hints=True,
+ local_device=local_device,
+ )
- loss.backward()
+ to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+ nb_wrong >= args.nb_have_to_be_wrong
+ )
- if nb_train_samples % args.batch_size == 0:
- optimizer.step()
+ nb_validated += to_keep.long().sum().item()
+ record.append(c_quizzes[to_keep])
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+ #####################
- log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+ duration = time.perf_counter() - start_time
+ if last_log < 0 or duration > last_log + 10:
+ last_log = duration
+ if nb_validated > 0:
+ if nb_validated < nb_to_generate:
+ d = (nb_to_generate - nb_validated) * duration / nb_validated
+ e = (
+ datetime.datetime.now() + datetime.timedelta(seconds=d)
+ ).strftime("%a %H:%M")
+ else:
+ e = "now!"
+ else:
+ e = "???"
-######################################################################
+ log_string(
+ f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+ )
+ #####################
-def run_tests(model, task, deterministic_synthesis):
- with torch.autograd.no_grad():
- model.eval()
+ duration = time.perf_counter() - start_time
- nb_test_samples, acc_test_loss = 0, 0.0
- nb_samples_accumulated = 0
+ log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
- for input in task.batches(split="test"):
- input = input.to(device)
+ return torch.cat(record).to("cpu")
- bs = model(mygpt.BracketedSequence(input))
- output = bs.x
- loss = F.cross_entropy(output.transpose(1, 2), input)
+######################################################################
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+def multithread_execution(fun, arguments):
+ # Single instance, no thread
+ if len(arguments) == 1:
+ return fun(*(arguments[0]))
- main_test_accuracy = task.produce_results(
- n_epoch=n_epoch,
- model=model,
- result_dir=args.result_dir,
- logger=log_string,
- deterministic_synthesis=deterministic_synthesis,
- )
+ records, threads = [], []
+
+ def threadable_fun(*args):
+ r = fun(*args)
+ if type(r) is not tuple:
+ r = (r,)
+ records.append(r)
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+ for args in arguments:
+ # To get a different sequence between threads
+ log_string(f"dummy_rand {torch.rand(1)}")
+ # torch.rand(1)
+ t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+ threads.append(t)
+ t.start()
- log_string(f"test_perplexity {n_epoch} {test_perplexity}")
+ for t in threads:
+ t.join()
- model.main_test_accuracy = main_test_accuracy
+ if records[0] == (None,):
+ return
+ else:
+ return [
+ torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+ ]
######################################################################
-def create_quizzes(
- model,
- other_models,
- task,
- nb_for_train=1000,
- nb_for_test=100,
-):
- kept = []
-
- while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
- new_quizzes, nb_correct = task.create_new_quizzes(
- n_epoch=n_epoch,
- result_dir=args.result_dir,
- logger=log_string,
- nb=4 * (nb_for_train + nb_for_test),
- model=model,
- other_models=other_models,
+def save_models(models, suffix=""):
+ if suffix != "":
+ suffix = "_" + suffix
+
+ for model in models:
+ filename = f"ae_{model.id:03d}{suffix}.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ },
+ os.path.join(args.result_dir, filename),
)
- to_keep = new_quizzes[nb_correct == len(other_models) - 1]
- log_string(f"keep {to_keep.size(0)} quizzes")
- kept.append(to_keep)
+ log_string(f"wrote ae_*{suffix}.pth")
- new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
- task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
- task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
+######################################################################
+
+
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
+ c_quizzes = c_quizzes.to(local_device)
+
+ nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
+ )
+
+ comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
- task.save_image(
- new_quizzes[:96],
+ problem.save_quizzes_as_image(
args.result_dir,
- f"world_new_{n_epoch:04d}_{model.id:02d}.png",
- log_string,
+ filename,
+ quizzes=c_quizzes,
+ comments=comments,
+ delta=True,
+ nrow=8,
)
+ log_string(f"wrote {filename}")
+
+
+######################################################################
+
+problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_world_tasks,
+)
+
+if not args.resume:
+ problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
models = []
-for k in range(5):
- model = mygpt.MyGPT(
- vocabulary_size=vocabulary_size,
+if args.model_type == "standard":
+ model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+ model_constructor = attae.FunctionalAttentionAE
+else:
+ raise ValueError(f"Unknown model type {args.model_type}")
+
+
+for i in range(args.nb_models):
+ model = model_constructor(
+ vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- causal=True,
dropout=args.dropout,
- ).to(device)
+ )
+
+ # model = torch.compile(model)
- model.main_test_accuracy = 0.0
- model.id = k
+ model.id = i
+ model.test_accuracy = 0.0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
models.append(model)
+######################################################################
+
+current_epoch = 0
+
+if args.resume:
+ for model in models:
+ filename = f"ae_{model.id:03d}.pth"
+
+ d = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.test_accuracy = d["test_accuracy"]
+ log_string(f"successfully loaded {filename}")
+
+ filename = "state.pth"
+ state = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+
+ log_string(f"successfully loaded {filename}")
+
+ current_epoch = state["current_epoch"]
+ train_c_quizzes = state["train_c_quizzes"]
+ test_c_quizzes = state["test_c_quizzes"]
+
+######################################################################
nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
######################################################################
-accuracy_to_make_quizzes = 0.975
+train_c_quizzes, test_c_quizzes = None, None
-for n_epoch in range(args.nb_epochs):
- # select the model with lowest accuracy
- models.sort(key=lambda model: model.main_test_accuracy)
- model = models[0]
+######################################################################
- log_string(
- f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
- )
+for n_epoch in range(current_epoch, args.nb_epochs):
+ start_time = time.perf_counter()
- # improve it
- one_epoch(model, task)
+ state = {
+ "current_epoch": n_epoch,
+ "train_c_quizzes": train_c_quizzes,
+ "test_c_quizzes": test_c_quizzes,
+ }
- log_string(
- f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
- )
+ filename = "state.pth"
+ torch.save(state, os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
+
+ log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+ cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
+ log_string(f"current_test_accuracies {cta}")
+
+ # --------------------------------------------------------------------
- # test it
- run_tests(model, task, deterministic_synthesis=False)
+ lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
- if model.main_test_accuracy >= accuracy_to_make_quizzes:
- other_models = models.copy()
- other_models.remove(model)
+ if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
+ if train_c_quizzes is None:
+ save_models(models, "naive")
- create_quizzes(
- model,
- other_models,
- task,
- nb_for_train=1000,
- nb_for_test=100,
+ nb_gpus = len(gpus)
+ nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
+
+ (new_c_quizzes,) = multithread_execution(
+ generate_c_quizzes,
+ [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+ )
+
+ save_quiz_image(
+ models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
)
+ log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
-######################################################################
+ train_c_quizzes = (
+ new_c_quizzes
+ if train_c_quizzes is None
+ else torch.cat([train_c_quizzes, new_c_quizzes])
+ )
+ train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
+
+ nb_correct, _ = evaluate_quizzes(
+ quizzes=train_c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
+ )
+
+ test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
+
+ for model in models:
+ model.test_accuracy = 0
+
+ if train_c_quizzes is None:
+ log_string("no_c_quiz")
+ else:
+ log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
+
+ # --------------------------------------------------------------------
+
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ weakest_models = ranked_models[: len(gpus)]
+
+ log_string(
+ f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+ )
+
+ multithread_execution(
+ one_complete_epoch,
+ [
+ (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
+ for model, gpu in zip(weakest_models, gpus)
+ ],
+ )
+
+ save_models(models)
+
+ # --------------------------------------------------------------------
+
+ duration = time.perf_counter() - start_time
+ str_duration = ""
+ if duration >= 60:
+ str_duration += f"{int(duration)//60}min"
+ str_duration += f"{int(duration)%60}s"
+ str_next = (
+ datetime.datetime.now() + datetime.timedelta(seconds=duration)
+ ).strftime("%H:%M:%S")
+ log_string(f"epoch_duration {str_duration} next_finish {str_next}")