From 4be8de19ab4a619c5762b77c369233d63638d445 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 16 Sep 2024 22:16:15 +0200 Subject: [PATCH] Update. --- grids.py | 10 +++ main.py | 210 +++++++++++++++++++++++++++++++++++++++++++++--- quiz_machine.py | 31 +++++++ 3 files changed, 238 insertions(+), 13 deletions(-) diff --git a/grids.py b/grids.py index 882c113..490750b 100755 --- a/grids.py +++ b/grids.py @@ -181,6 +181,16 @@ class Grids(problem.Problem): return quizzes + def pure_noise(self, nb, device): + result = torch.randint( + self.nb_colors, (nb, 4 * (self.height * self.height + 1)), device=device + ) + result.view(nb, 4, -1)[:, 0, 0] = self.token_A + result.view(nb, 4, -1)[:, 1, 0] = self.token_f_A + result.view(nb, 4, -1)[:, 2, 0] = self.token_B + result.view(nb, 4, -1)[:, 3, 0] = self.token_f_B + return result + # What a mess def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")): if torch.is_tensor(quizzes): diff --git a/main.py b/main.py index 0d46aa2..649c889 100755 --- a/main.py +++ b/main.py @@ -69,10 +69,6 @@ parser.add_argument("--nb_test_alien_samples", type=int, default=0) parser.add_argument("--nb_c_quizzes", type=int, default=10000) -parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) - -parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) - parser.add_argument("--c_quiz_multiplier", type=int, default=1) parser.add_argument("--learning_rate", type=float, default=5e-4) @@ -115,9 +111,9 @@ parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_models", type=int, default=5) -parser.add_argument("--nb_diffusion_iterations", type=int, default=25) +parser.add_argument("--diffusion_nb_iterations", type=int, default=25) -parser.add_argument("--proba_diffusion_corruption", type=float, default=0.05) +parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05) parser.add_argument("--min_succeed_to_validate", type=int, default=2) @@ -336,7 +332,7 @@ def mu_T_sampler(shape, device="cpu"): diffuser = diffusion.Diffuser( - mu_T_sampler, args.nb_diffusion_iterations, args.proba_diffusion_corruption + mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption ) ###################################################################### @@ -470,6 +466,10 @@ def batches( ) +def NTC_channel_cat(*x): + return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2) + + def NTC_masked_cross_entropy(output, targets, mask): loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") return (loss_per_token * mask).mean() @@ -567,7 +567,7 @@ def run_test( ###################################################################### -def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device): +def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device): model.train().to(local_device) optimizer_to(model.optimizer, local_device) @@ -635,6 +635,193 @@ def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device) ) +###################################################################### + + +def batch_prediction(input, proba_hints=0.0): + nb = input.size(0) + mask_generate = input.new_zeros(input.size()) + u = F.one_hot(torch.randint(4, (nb,), device=mask_generate.device), num_classes=4) + mask_generate.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] + + if proba_hints > 0: + h = torch.rand(input.size(), device=input.device) * mask_generate + mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints + v = torch.rand(nb, device=input.device)[:, None] + mask_hints = mask_hints * (v < proba_hints).long() + mask_generate = (1 - mask_hints) * mask_generate + + # noise = quiz_machine.problem.pure_noise(nb, input.device) + targets = input + input = (1 - mask_generate) * targets # + mask_generate * noise + + return input, targets, mask_generate + + +def predict(model, quizzes, local_device=main_device): + model.eval().to(local_device) + + input, targets, mask = batch_prediction(quizzes.to(local_device)) + + input_batches = input.reshape(-1, args.physical_batch_size, input.size(1)) + targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1)) + mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1)) + + record = [] + + for input, targets, mask in tqdm.tqdm( + zip(input_batches, targets_batches, mask_batches), + dynamic_ncols=True, + desc="predict", + total=quizzes.size(0) // args.physical_batch_size, + ): + # noise = quiz_machine.problem.pure_noise(input.size(0), input.device) + input = (1 - mask) * input # + mask * noise + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model(NTC_channel_cat(input, mask)) + dist = torch.distributions.categorical.Categorical(logits=logits) + result = (1 - mask) * input + mask * dist.sample() + record.append(result) + + return torch.cat(record) + + +###################################################################### + + +def batch_generation(input): + nb = input.size(0) + probs_iterations = 0.1 ** torch.linspace( + 0, 1, args.diffusion_nb_iterations, device=input.device + ) + probs_iterations = probs_iterations[None, :] / probs_iterations.sum() + probs_iterations = probs_iterations.expand(nb, -1) + dist = torch.distributions.categorical.Categorical(probs=probs_iterations) + t = dist.sample() + 1 + r = torch.rand(input.size(), device=input.device) + proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t + mask_erased = (r <= proba_erased[:, None]).long() + + noise = quiz_machine.problem.pure_noise(nb, input.device) + + targets = input + input = (1 - mask_erased) * input + mask_erased * noise + mask_generate = input.new_full(input.size(), 1) + mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0 + + return input, targets, mask_generate + + +def prioritized_rand(low): + x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values + k = torch.rand(low.size(), device=low.device) + low.long() + k = k.sort(dim=1).indices + y = x.new(x.size()) + y.scatter_(dim=1, index=k, src=x) + return y + + +def generate(model, nb, local_device=main_device): + input = quiz_machine.problem.pure_noise(nb, local_device) + mask_generate = input.new_full(input.size(), 1) + mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0 + + changed = True + for it in range(self.diffusion_nb_iterations): + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model(NTC_channel_cat(input, mask_generate)) + dist = torch.distributions.categorical.Categorical(logits=logits) + output = dist.sample() + + r = self.prioritized_rand(input != output) + mask_changes = (r <= self.proba_corruption).long() + update = (1 - mask_changes) * input + mask_changes * output + + if update.equal(input): + break + else: + changed = changed & (update != input).max(dim=1).values + input[changed] = update[changed] + + return input + + +###################################################################### + + +def batch_interleave(a, b, perm): + return torch.cat([a, b])[perm].reshape(-1, args.physical_batch_size, a.size(1)) + + +def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): + if train: + label = "train" + model.train().to(local_device) + optimizer_to(model.optimizer, local_device) + else: + label = "test" + model.eval().to(local_device) + + nb_samples, acc_loss = 0, 0.0 + + quizzes = quiz_machine.quiz_set( + args.nb_train_samples if train else args.nb_test_samples, + c_quizzes, + args.c_quiz_multiplier, + ) + + input_p, input_g = quizzes.to(local_device).chunk(2) + input_p, targets_p, mask_p = batch_prediction(input_p, proba_hints=0.5) + input_g, targets_g, mask_g = batch_generation(input_g) + + perm = torch.randperm(quizzes.size(0), device=local_device) + input_batches = batch_interleave(input_p, input_g, perm) + targets_batches = batch_interleave(targets_p, targets_g, perm) + mask_batches = batch_interleave(mask_p, mask_g, perm) + + for input, targets, mask in tqdm.tqdm( + zip(input_batches, targets_batches, mask_batches), + dynamic_ncols=True, + desc=label, + total=quizzes.size(0) // args.physical_batch_size, + ): + if train and nb_samples % args.batch_size == 0: + model.optimizer.zero_grad() + + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model(NTC_channel_cat(input, mask)) + + loss = NTC_masked_cross_entropy(logits, targets, mask) + acc_loss += loss.item() * input.size(0) + nb_samples += input.size(0) + + if train: + loss.backward() + + if nb_samples % args.batch_size == 0: + model.optimizer.step() + + log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}") + + +def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): + one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True) + + one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False) + + quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) + result = predict(model, quizzes).to("cpu") + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + f"culture_prediction_{n_epoch}_{model.id}.png", + quizzes=result[:128], + ) + + nb_correct = (quizzes == result).min(dim=1).values.long().sum() + model.test_accuracy = nb_correct / quizzes.size(0) + + ###################################################################### import attae @@ -1099,11 +1286,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], multithread_execution( - one_epoch, - [ - (model, quiz_machine, n_epoch, c_quizzes, gpu) - for model, gpu in zip(weakest_models, gpus) - ], + one_train_test_epoch, + [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)], ) # -------------------------------------------------------------------- diff --git a/quiz_machine.py b/quiz_machine.py index f1eb9db..781c1cf 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -195,6 +195,37 @@ class QuizMachine: ###################################################################### + def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1): + if c_quizzes is None: + quizzes = self.problem.generate_w_quizzes(nb_samples) + else: + if c_quiz_multiplier > 1: + n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) + body = c_quizzes.repeat(n, 1) + if n < c_quiz_multiplier: + tail = c_quizzes[ + torch.randperm(c_quizzes.size(0))[ + : nb_samples // 2 - body.size(0) + ] + ] + c_quizzes = torch.cat([body, tail], dim=0) + else: + c_quizzes = body + + if c_quizzes.size(0) > nb_samples // 2: + i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] + c_quizzes = c_quizzes[i] + + w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0)) + quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) + + i = torch.randperm(quizzes.size(0), device=quizzes.device) + quizzes = quizzes[i].contiguous() + + return quizzes + + ###################################################################### + def make_quiz_mask(self, quizzes, quad_order, quad_mask): assert quad_order in [s for s, _, _, _ in self.train_structures] return self.problem.make_quiz_mask( -- 2.39.5