# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, argparse, time, tqdm, os, datetime, warnings
+import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
import torch, torchvision
from torch import nn
from torch.nn import functional as F
-import ffutils
+import ffutils, grids, attae
-import mygpt
-import sky, grids, quiz_machine
+import threading, subprocess
-import threading
+# import torch.multiprocessing as mp
-import torch.multiprocessing as mp
+torch.set_float32_matmul_precision("high")
+
+# torch.set_default_dtype(torch.bfloat16)
######################################################################
parser.add_argument("--resume", action="store_true", default=False)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
-
-########################################
+# ----------------------------------
parser.add_argument("--nb_epochs", type=int, default=10000)
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--train_batch_size", type=int, default=None)
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--eval_batch_size", type=int, default=25)
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=50000)
-parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
-parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+parser.add_argument("--nb_c_quizzes", type=int, default=5000)
+
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-4)
-########################################
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
+# ----------------------------------
+
+parser.add_argument("--model_type", type=str, default="standard")
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--model", type=str, default="37M")
parser.add_argument("--dim_model", type=int, default=None)
parser.add_argument("--nb_blocks", type=int, default=None)
-parser.add_argument("--dropout", type=float, default=0.1)
+parser.add_argument("--dropout", type=float, default=0.5)
-########################################
-
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-
-parser.add_argument("--problem", type=str, default="grids")
+# ----------------------------------
parser.add_argument("--nb_threads", type=int, default=1)
parser.add_argument("--gpus", type=str, default="all")
-parser.add_argument("--nb_gpts", type=int, default=5)
+# ----------------------------------
+
+parser.add_argument("--nb_models", type=int, default=5)
+
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
-parser.add_argument("--proba_understands", type=float, default=0.99)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
-parser.add_argument("--generation_temperature", type=float, default=2.0)
+parser.add_argument("--proba_hint", type=float, default=0.25)
-parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument("--quizzes", type=str, default=None)
######################################################################
)
parser.add_argument(
- "--grids_tasks",
+ "--grids_world_tasks",
type=str,
- default=None,
- help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+ default="replace_color,translate,grow,frame",
+ help="A comma-separated subset of: " + grids_tasks + ".",
)
######################################################################
-parser.add_argument("--sky_height", type=int, default=6)
-
-parser.add_argument("--sky_width", type=int, default=8)
-
-parser.add_argument("--sky_nb_birds", type=int, default=3)
-
-parser.add_argument("--sky_nb_iterations", type=int, default=2)
-
-parser.add_argument("--sky_speed", type=int, default=3)
-
-######################################################################
-
args = parser.parse_args()
if args.result_dir is None:
######################################################################
-default_args = {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 100000,
- "nb_test_samples": 10000,
-}
-
-for k, v in default_args.items():
- if getattr(args, k) is None:
- setattr(args, k, v)
-
-######################################################################
-
default_model_args = {
"17K": {
"dim_model": 32,
######################################################################
if args.resume:
- assert os.path.isdir(args.result_dir)
-
+ if not os.path.isdir(args.result_dir):
+ print(f"Trying to resume from a non-existing result dir {args.result_dir}.")
+ exit(1)
else:
try:
os.mkdir(args.result_dir)
def log_string(s):
+ """print the given string prefixed with a time stamps, and log it
+ into log_file is not None"""
+
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
if log_file is not None:
sys.stdout.flush()
+######################################################################
+# Create a time-stamped archive of the source code
+
+with open("this_run.sh", "w") as f:
+ f.write(f"{' '.join(sys.argv)}\n")
+
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py")
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+
+######################################################################
log_string(f"argv {' '.join(sys.argv)}")
assert len(gpus) == 0
main_device = torch.device("cpu")
-if args.dirty_debug:
- args.nb_train_samples = 2500
- args.nb_test_samples = 100
-
-if args.physical_batch_size is None:
- args.physical_batch_size = args.batch_size
+if args.train_batch_size is None:
+ args.train_batch_size = args.batch_size
else:
- assert args.batch_size % args.physical_batch_size == 0
+ assert args.batch_size % args.train_batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-if args.problem == "sky":
- problem = sky.Sky(
- height=args.sky_height,
- width=args.sky_width,
- nb_birds=args.sky_nb_birds,
- nb_iterations=args.sky_nb_iterations,
- speed=args.sky_speed,
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- )
- back_accuracy = False
-elif args.problem == "grids":
- problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- tasks=args.grids_tasks,
- )
- back_accuracy = True
-else:
- raise ValueError
-
-problem.save_some_examples(args.result_dir)
-
-quiz_machine = quiz_machine.QuizMachine(
- problem=problem,
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- back_accuracy=back_accuracy,
- batch_size=args.physical_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=main_device,
-)
+######################################################################
+
+
+def optimizer_to(optim, device):
+ """Move the optimizer optim to the device"""
+ for param in optim.state.values():
+ # Not sure there are any global tensors in the state dict
+ if isinstance(param, torch.Tensor):
+ param.data = param.data.to(device)
+ if param._grad is not None:
+ param._grad.data = param._grad.data.to(device)
+ elif isinstance(param, dict):
+ for subparam in param.values():
+ if isinstance(subparam, torch.Tensor):
+ subparam.data = subparam.data.to(device)
+ if subparam._grad is not None:
+ subparam._grad.data = subparam._grad.data.to(device)
+
######################################################################
-log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
-vocabulary_size = quiz_machine.vocabulary_size()
+def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+ if c_quizzes is None:
+ quizzes = problem.generate_w_quizzes(nb_samples)
+ nb_w_quizzes = quizzes.size(0)
+ nb_c_quizzes = 0
+ else:
+ if c_quiz_multiplier > 1:
+ n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+ body = c_quizzes.repeat(n, 1)
+ if n < c_quiz_multiplier:
+ tail = c_quizzes[
+ torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+ ]
+ c_quizzes = torch.cat([body, tail], dim=0)
+ else:
+ c_quizzes = body
+
+ if c_quizzes.size(0) > nb_samples // 2:
+ i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+ c_quizzes = c_quizzes[i]
+
+ w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+
+ quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+ nb_w_quizzes = w_quizzes.size(0)
+ nb_c_quizzes = c_quizzes.size(0)
+
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
+ quizzes = quizzes[i].contiguous()
+
+ log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+ return quizzes
-log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
-def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
- with torch.autograd.no_grad():
- model.eval().to(local_device)
+def add_hints_imt(imt_set):
+ """Set every component of the mask to zero with probability
+ args.proba_hint, and for each component set to zero, copy the
+ corresponding value from the target into the input
+
+ """
+ input, masks, targets = imt_set.unbind(dim=1)
+ # h = torch.rand(masks.size(), device=masks.device) - masks
+ # t = h.sort(dim=1).values[:, args.nb_hints, None]
+ # mask_hints = (h < t).long()
+ mask_hints = (
+ torch.rand(input.size(), device=input.device) < args.proba_hint
+ ).long() * masks
+ masks = (1 - mask_hints) * masks
+ input = (1 - mask_hints) * input + mask_hints * targets
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def add_noise_imt(imt_set):
+ """Replace every component of the input by a random value with
+ probability args.proba_prompt_noise."""
+ input, masks, targets = imt_set.unbind(dim=1)
+ noise = problem.pure_noise(input.size(0), input.device)
+ change = (1 - masks) * (
+ torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
+ ).long()
+ input = (1 - change) * input + change * noise
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
- nb_test_samples, acc_test_loss = 0, 0.0
- nb_samples_accumulated = 0
+######################################################################
+# Prediction
- for input in quiz_machine.batches(model, split="test"):
- input = input.to(local_device)
- bs = model(mygpt.BracketedSequence(input))
- output = bs.x
+def samples_for_prediction_imt(input):
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
- loss = F.cross_entropy(output.transpose(1, 2), input)
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+def ae_predict(model, imt_set, local_device=main_device):
+ model.eval().to(local_device)
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+ record = []
- log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
+ src = tqdm.tqdm(
+ imt_set.split(args.eval_batch_size),
+ dynamic_ncols=True,
+ desc="predict",
+ total=imt_set.size(0) // args.eval_batch_size,
+ delay=10,
+ )
- model.main_test_accuracy = quiz_machine.produce_results(
- n_epoch=n_epoch,
- model=model,
- result_dir=args.result_dir,
- deterministic_synthesis=deterministic_synthesis,
+ for imt in src:
+ # some paranoia
+ imt = imt.clone()
+ imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(imt[:, 0] * 2 + imt[:, 1])
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
+ record.append(result)
+
+ return torch.cat(record)
+
+
+def predict_the_four_grids(
+ model, input, with_noise=False, with_hints=False, local_device=main_device
+):
+ input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
+ nb = input.size(0)
+ masks = input.new_zeros(input.size())
+ u = F.one_hot(torch.arange(nb, device=masks.device) % 4, num_classes=4)
+ masks.view(nb, 4, -1)[...] = u[:, :, None]
+ targets = input
+ input = (1 - masks) * targets
+ imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+ if with_hints:
+ imt_set = add_hints_imt(imt_set)
+
+ if with_noise:
+ imt_set = add_noise_imt(imt_set)
+
+ result = ae_predict(model, imt_set, local_device=local_device)
+ result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
+
+ return result
+
+
+######################################################################
+
+
+def samples_for_generation_imt(input):
+ nb = input.size(0)
+ probs_iterations = 0.1 ** torch.linspace(
+ 0, 1, args.diffusion_nb_iterations, device=input.device
+ )
+ probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+ probs_iterations = probs_iterations.expand(nb, -1)
+ dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+ t = dist.sample() + 1
+ r = torch.rand(input.size(), device=input.device)
+ proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+ mask_erased = (r <= proba_erased[:, None]).long()
+
+ noise = problem.pure_noise(nb, input.device)
+ targets = input
+ input = (1 - mask_erased) * input + mask_erased * noise
+ masks = input.new_full(input.size(), 1)
+
+ return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
+def prioritized_rand(low):
+ x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+ k = torch.rand(low.size(), device=low.device) + low.long()
+ k = k.sort(dim=1).indices
+ y = x.new(x.size())
+ y.scatter_(dim=1, index=k, src=x)
+ return y
+
+
+def ae_generate(model, nb, local_device=main_device):
+ model.eval().to(local_device)
+
+ # We loop through the iterations first and through the
+ # mini-batches second so that we keep only the samples that have
+ # not stabilized
+
+ all_input = problem.pure_noise(nb, local_device)
+ all_masks = all_input.new_full(all_input.size(), 1)
+ all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
+
+ for it in range(args.diffusion_nb_iterations):
+ # log_string(f"nb_changed {all_changed.long().sum().item()}")
+
+ if not all_changed.any():
+ break
+
+ sub_input = all_input[all_changed].clone()
+ sub_masks = all_masks[all_changed].clone()
+ sub_changed = all_changed[all_changed].clone()
+
+ src = zip(
+ sub_input.split(args.eval_batch_size),
+ sub_masks.split(args.eval_batch_size),
+ sub_changed.split(args.eval_batch_size),
)
+ for input, masks, changed in src:
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ output = dist.sample()
+ r = prioritized_rand(input != output)
+ mask_changes = (r <= args.diffusion_proba_corruption).long() * masks
+ update = (1 - mask_changes) * input + mask_changes * output
+ changed[...] = changed & (update != input).max(dim=1).values
+ input[...] = update
-def one_epoch(model, quiz_machine, local_device=main_device):
- model.to(local_device).train()
+ a = all_changed.clone()
+ all_input[a] = sub_input
+ all_masks[a] = sub_masks
+ all_changed[a] = sub_changed
- optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ return all_input
- nb_train_samples, acc_train_loss = 0, 0.0
- for input in quiz_machine.batches(model, split="train"):
- input = input.to(local_device)
+######################################################################
- if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
+def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device):
+ quizzes = generate_quiz_set(
+ args.nb_train_samples if train else args.nb_test_samples,
+ c_quizzes,
+ args.c_quiz_multiplier,
+ )
+
+ q_p, q_g = quizzes.to(local_device).chunk(2)
+
+ # Half of the samples train the prediction, and we inject noise in
+ # all, and hints in half
+ b_p = samples_for_prediction_imt(q_p)
+ b_p = add_noise_imt(b_p)
+ half = torch.rand(b_p.size(0)) < 0.5
+ b_p[half] = add_hints_imt(b_p[half])
+
+ # The other half are denoising examples for the generation
+ b_g = samples_for_generation_imt(q_g)
+
+ imt_set = torch.cat([b_p, b_g])
+ imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+
+ if train:
+ label = "train"
+ model.train().to(local_device)
+ optimizer_to(model.optimizer, local_device)
+ batch_size = args.train_batch_size
+ else:
+ label = "test"
+ model.eval().to(local_device)
+ batch_size = args.eval_batch_size
+
+ nb_samples, acc_loss = 0, 0.0
+
+ for imt in tqdm.tqdm(
+ imt_set.split(batch_size),
+ dynamic_ncols=True,
+ desc=label,
+ total=quizzes.size(0) // batch_size,
+ delay=10,
+ ):
+ input, masks, targets = imt.unbind(dim=1)
+ if train and nb_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+
+ loss_per_token = F.cross_entropy(
+ logits.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * masks).mean()
+ acc_loss += loss.item() * imt.size(0)
+ nb_samples += imt.size(0)
+
+ if train:
+ loss.backward()
- nb_train_samples += input.size(0)
+ if nb_samples % args.batch_size == 0:
+ model.optimizer.step()
- loss.backward()
+ log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
- if nb_train_samples % args.batch_size == 0:
- optimizer.step()
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+######################################################################
- log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
- run_tests(model, quiz_machine, deterministic_synthesis=False)
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
+ # Save some images of the prediction results
- model.to(main_device)
+ quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ masks = imt_set[:, 1].to("cpu")
+
+ correct = (quizzes == result).min(dim=1).values.long()
+ correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
+ :, :, 1
+ ]
+ predicted_parts = correct_parts.abs()
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_prediction_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ predicted_parts=predicted_parts[:128],
+ correct_parts=correct_parts[:128],
+ )
+
+ # Save some images of the ex nihilo generation of the four grids
+
+ result = ae_generate(model, 150, local_device=local_device).to("cpu")
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_generation_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ )
######################################################################
-# This is the key routine that decides what generated quizzes to keep
+def one_complete_epoch(
+ model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
+):
+ one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
+
+ one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
+
+ # Compute the test accuracy
+
+ quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ correct = (quizzes == result).min(dim=1).values.long()
+
+ nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
+ model.test_accuracy = nb_correct / nb_total
-def compute_valid_quizzes(token_logprobas):
- warnings.warn("validation with uniform constraints", RuntimeWarning)
- l = token_logprobas.min(dim=-1).values.sort(dim=-1).values
- return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5))
+ log_string(
+ f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
+ )
+
+ save_inference_images(
+ model, n_epoch, c_quizzes, args.c_quiz_multiplier, local_device=local_device
+ )
+
+
+######################################################################
-def compute_valid_quizzes_(token_logprobas):
- l = token_logprobas.sum(dim=-1).sort(dim=-1).values
- return (l[:, 0] < math.log(args.proba_not_understands)) & (
- l[:, 1] > math.log(args.proba_understands)
+def max_nb_mistakes_on_one_grid(quizzes, prediction):
+ return (
+ (prediction != quizzes)
+ .long()
+ .reshape(quizzes.size(0), 4, -1)
+ .sum(dim=2)
+ .max(dim=1)
+ .values
)
-def extract_valid_quizzes_and_logprobas(recorded):
- validated_quizzes, validated_logprobas = [], []
- for quizzes, token_logprobas in recorded:
- validated_indices = compute_valid_quizzes(token_logprobas)
- validated_quizzes.append(quizzes[validated_indices])
- validated_logprobas.append(token_logprobas[validated_indices])
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
+ nb_correct, nb_wrong = 0, 0
- if len(validated_quizzes) > 0:
- return torch.cat(validated_quizzes, dim=0), torch.cat(
- validated_logprobas, dim=0
+ for model in models:
+ model = copy.deepcopy(model).to(local_device).eval()
+ predicted = predict_the_four_grids(
+ model=model,
+ input=quizzes,
+ with_noise=False,
+ with_hints=with_hints,
+ local_device=local_device,
)
- else:
- return None, None
+ nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
+ nb_correct += (nb_mistakes == 0).long()
+ nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
+
+ # print("\n\n", nb_correct, nb_wrong)
+
+ return nb_correct, nb_wrong
######################################################################
-def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
- nb_to_create = nb_for_train + nb_for_test
+def identity_quizzes(quizzes):
+ quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
+ return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
+ quizzes[:, 2] == quizzes[:, 3]
+ ).min(dim=1).values
- recorded_quizzes_logprobas = []
+def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
+ record = []
nb_validated = 0
- while nb_validated < nb_to_create:
- model_for_generation = models[torch.randint(len(models), (1,))]
+ start_time = time.perf_counter()
+ last_log = -1
- c_quizzes = quiz_machine.generate_quizzes(
- nb_to_create,
- model_for_generation=model_for_generation,
- temperature=args.generation_temperature,
+ while nb_validated < nb_to_generate:
+ # Generate new quizzes
+
+ model = models[torch.randint(len(models), (1,)).item()]
+ model = copy.deepcopy(model).to(local_device).eval()
+ generator_id = model.id
+
+ c_quizzes = ae_generate(
+ model=model, nb=args.eval_batch_size * 10, local_device=local_device
)
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
+ c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
if c_quizzes.size(0) > 0:
- token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
- recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
+ # Select the ones that are solved properly by some models and
+ # not understood by others
+
+ nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ with_hints=True,
+ local_device=local_device,
+ )
+
+ to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+ nb_wrong >= args.nb_have_to_be_wrong
+ )
+
+ nb_validated += to_keep.long().sum().item()
+ record.append(c_quizzes[to_keep])
+
+ #####################
+
+ duration = time.perf_counter() - start_time
+
+ if last_log < 0 or duration > last_log + 10:
+ last_log = duration
+ if nb_validated > 0:
+ if nb_validated < nb_to_generate:
+ d = (nb_to_generate - nb_validated) * duration / nb_validated
+ e = (
+ datetime.datetime.now() + datetime.timedelta(seconds=d)
+ ).strftime("%a %H:%M")
+ else:
+ e = "now!"
+ else:
+ e = "???"
+
+ log_string(
+ f"nb_validated {nb_validated} model {generator_id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+ )
+
+ #####################
+
+ duration = time.perf_counter() - start_time
+
+ log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h")
+
+ return torch.cat(record).to("cpu")
+
+
+######################################################################
- (
- validated_quizzes,
- validated_logprobas,
- ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
- if validated_quizzes is not None:
- nb_validated = validated_quizzes.size(0)
+def multithread_execution(fun, arguments):
+ # Single instance, no thread
+ if len(arguments) == 1:
+ return fun(*(arguments[0]))
- log_string(
- f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+ records, threads = [], []
+
+ def threadable_fun(*args):
+ r = fun(*args)
+ if type(r) is not tuple:
+ r = (r,)
+ records.append(r)
+
+ for args in arguments:
+ # To get a different sequence between threads
+ log_string(f"dummy_rand {torch.rand(1)}")
+ # torch.rand(1)
+ t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+ threads.append(t)
+ t.start()
+
+ for t in threads:
+ t.join()
+
+ if records[0] == (None,):
+ return
+ else:
+ return [
+ torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+ ]
+
+
+######################################################################
+
+
+def save_models(models, suffix=""):
+ if suffix != "":
+ suffix = "_" + suffix
+
+ for model in models:
+ filename = f"ae_{model.id:03d}{suffix}.pth"
+ torch.save(
+ {
+ "state_dict": model.state_dict(),
+ "optimizer_state_dict": model.optimizer.state_dict(),
+ "test_accuracy": model.test_accuracy,
+ },
+ os.path.join(args.result_dir, filename),
)
- # store the new c_quizzes which have been validated
+ log_string(f"wrote ae_*{suffix}.pth")
+
+
+######################################################################
+
- quiz_machine.reverse_random_half_in_place(validated_quizzes)
- quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
- quiz_machine.store_c_quizzes(
- validated_quizzes[nb_for_train:nb_to_create], for_train=False
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
+ c_quizzes = c_quizzes.to(local_device)
+
+ nb_correct, nb_wrong = evaluate_quizzes(
+ quizzes=c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
+ )
+
+ comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=c_quizzes,
+ comments=comments,
+ delta=True,
+ nrow=8,
)
- ######################################################################
- # save images with their logprobas
+ log_string(f"wrote {filename}")
- vq = validated_quizzes[:72]
- vl = validated_logprobas[:72]
- if vq.size(0) > 0:
- prefix = f"culture_c_quiz_{n_epoch:04d}"
- filename = os.path.join(args.result_dir, prefix + "_logp.pth")
- torch.save(vl, filename)
- # with open(file_name, "w") as logp_file:
- # for l in vl:
- # s = " ".join([str(x.item()) for x in l])
- # logp_file.write(s + "\n")
+######################################################################
- quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
+problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_world_tasks,
+)
+if not args.resume:
+ problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
######################################################################
models = []
-for k in range(args.nb_gpts):
- log_string(f"creating model {k} and its w_quizzes")
- model = mygpt.MyGPT(
- vocabulary_size=vocabulary_size,
+if args.model_type == "standard":
+ model_constructor = attae.AttentionAE
+elif args.model_type == "functional":
+ model_constructor = attae.FunctionalAttentionAE
+else:
+ raise ValueError(f"Unknown model type {args.model_type}")
+
+
+for i in range(args.nb_models):
+ model = model_constructor(
+ vocabulary_size=vocabulary_size * 2,
dim_model=args.dim_model,
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,
nb_heads=args.nb_heads,
nb_blocks=args.nb_blocks,
- causal=True,
dropout=args.dropout,
- ).to(main_device)
+ )
- model.main_test_accuracy = 0.0
- model.id = k
+ # model = torch.compile(model)
- model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
- quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
- model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
- quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+ model.id = i
+ model.test_accuracy = 0.0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
models.append(model)
######################################################################
-if args.resume:
- try:
- for model in models:
- filename = f"gpt_{model.id:03d}.pth"
-
- try:
- d = torch.load(os.path.join(args.result_dir, filename))
- model.load_state_dict(d[0])
- model.main_test_accuracy = d[1]
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
- try:
- filename = "c_quizzes.pth"
- quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
- except:
- log_string(f"error when loading {filename}.")
- exit(1)
+current_epoch = 0
-######################################################################
-
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+if args.resume:
+ for model in models:
+ filename = f"ae_{model.id:03d}.pth"
-######################################################################
+ d = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
+ model.load_state_dict(d["state_dict"])
+ model.optimizer.load_state_dict(d["optimizer_state_dict"])
+ model.test_accuracy = d["test_accuracy"]
+ log_string(f"successfully loaded {filename}")
+
+ filename = "state.pth"
+ state = torch.load(
+ os.path.join(args.result_dir, filename),
+ map_location="cpu",
+ weights_only=False,
+ )
-# Compute the entropy of the training tokens
+ log_string(f"successfully loaded {filename}")
-token_count = 0
-for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
+ current_epoch = state["current_epoch"]
+ train_c_quizzes = state["train_c_quizzes"]
+ test_c_quizzes = state["test_c_quizzes"]
######################################################################
-# A bit of paranoia never hurts
-
-if args.max_percents_of_test_in_train >= 0:
-
- def subsets_as_tuples(batches, cs):
- s = set()
- for batch in batches:
- for x in batch:
- s.add(tuple([v.item() for v in x]))
- if len(s) == cs:
- yield s
- s = set()
- yield s
-
- nb_test, nb_in_train = 0, 0
- for test_subset in subsets_as_tuples(
- quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
- ):
- in_train = set()
- for train_subset in subsets_as_tuples(
- quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
- ):
- in_train.update(test_subset.intersection(train_subset))
- nb_in_train += len(in_train)
- nb_test += len(test_subset)
- log_string(
- f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
- )
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
- assert (
- nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
- ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
######################################################################
-if args.nb_new_c_quizzes_for_train is None:
- args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-
-if args.nb_new_c_quizzes_for_test is None:
- args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
-
-log_string(
- f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
-)
+train_c_quizzes, test_c_quizzes = None, None
######################################################################
-if args.dirty_debug:
- args.accuracy_to_make_c_quizzes = 0.0
- args.nb_gpts = 2
- args.nb_new_c_quizzes_for_train = 100
- args.nb_new_c_quizzes_for_test = 10
+for n_epoch in range(current_epoch, args.nb_epochs):
+ start_time = time.perf_counter()
+ state = {
+ "current_epoch": n_epoch,
+ "train_c_quizzes": train_c_quizzes,
+ "test_c_quizzes": test_c_quizzes,
+ }
-######################################################################
+ filename = "state.pth"
+ torch.save(state, os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
-for n_epoch in range(args.nb_epochs):
log_string(f"--- epoch {n_epoch} ----------------------------------------")
- cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+ cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
- ##################################################
- # If all the models are good enough, generate new quizzes and
- # re-compute the test errors
+ # --------------------------------------------------------------------
- if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
- create_c_quizzes(
- models,
- quiz_machine,
- nb_for_train=args.nb_new_c_quizzes_for_train,
- nb_for_test=args.nb_new_c_quizzes_for_test,
- )
+ lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
- filename = "c_quizzes.pth"
- quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
- log_string(f"wrote {filename}")
+ if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
+ if train_c_quizzes is None:
+ save_models(models, "naive")
- ##################################################
- # Select, improve, and eval the worst model
+ nb_gpus = len(gpus)
+ nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
- ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+ (new_c_quizzes,) = multithread_execution(
+ generate_c_quizzes,
+ [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+ )
- weakest_models = ranked_models[: len(gpus)]
+ save_quiz_image(
+ models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
+ )
- threads = []
+ log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
- for gpu, model in zip(gpus, weakest_models):
- log_string(f"training model {model.id}")
+ train_c_quizzes = (
+ new_c_quizzes
+ if train_c_quizzes is None
+ else torch.cat([train_c_quizzes, new_c_quizzes])
+ )
+ train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
- t = threading.Thread(
- target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+ nb_correct, _ = evaluate_quizzes(
+ quizzes=train_c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
)
- threads.append(t)
+ test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
- t.start()
+ for model in models:
+ model.test_accuracy = 0
- for t in threads:
- t.join()
+ if train_c_quizzes is None:
+ log_string("no_c_quiz")
+ else:
+ log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
- # Save the models to disk
+ # --------------------------------------------------------------------
- for model in weakest_models:
- filename = f"gpt_{model.id:03d}.pth"
- torch.save(
- (model.state_dict(), model.main_test_accuracy),
- os.path.join(args.result_dir, filename),
- )
- log_string(f"wrote {filename}")
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ weakest_models = ranked_models[: len(gpus)]
+
+ log_string(
+ f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+ )
- # Renew the training samples
+ multithread_execution(
+ one_complete_epoch,
+ [
+ (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
+ for model, gpu in zip(weakest_models, gpus)
+ ],
+ )
- for model in weakest_models:
- quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
+ save_models(models)
+ # --------------------------------------------------------------------
-######################################################################
+ duration = time.perf_counter() - start_time
+ str_duration = ""
+ if duration >= 60:
+ str_duration += f"{int(duration)//60}min"
+ str_duration += f"{int(duration)%60}s"
+ str_next = (
+ datetime.datetime.now() + datetime.timedelta(seconds=duration)
+ ).strftime("%H:%M:%S")
+ log_string(f"epoch_duration {str_duration} next_finish {str_next}")