From 221ad2d2f3f1705a927220f4782366b150f4ed54 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 17 Sep 2024 09:30:35 +0200 Subject: [PATCH] Update. --- main.py | 98 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/main.py b/main.py index e921ccd..3339838 100755 --- a/main.py +++ b/main.py @@ -475,6 +475,11 @@ def NTC_masked_cross_entropy(output, targets, mask): return (loss_per_token * mask).mean() +def masked_cross_entropy(output, targets, masks): + loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") + return (loss_per_token * masks).sum() / masks.expand_as(loss_per_token).sum() + + ###################################################################### @@ -506,7 +511,7 @@ def run_test( x_0=x_0, mask_generate=mask_generate, ) - loss = NTC_masked_cross_entropy(logits, x_0, mask_generate) + loss = masked_cross_entropy(logits, x_0, mask_generate) acc_test_loss += loss.item() * x_0.size(0) nb_test_samples += x_0.size(0) @@ -600,7 +605,7 @@ def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device): nb_hints=nb_hints, ) - loss = NTC_masked_cross_entropy(logits, x_0, mask_generate) + loss = masked_cross_entropy(logits, x_0, mask_generate) acc_train_loss += loss.item() * x_0.size(0) nb_train_samples += x_0.size(0) @@ -638,47 +643,43 @@ def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device): ###################################################################### -def batch_prediction(input, proba_hints=0.0): +def IMT_batch_prediction(input, proba_hints=0.0): nb = input.size(0) - mask = input.new_zeros(input.size()) - u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4) - mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] + masks = input.new_zeros(input.size()) + u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4) + masks.view(nb, 4, -1)[:, :, 1:] = u[:, :, None] if proba_hints > 0: - h = torch.rand(input.size(), device=input.device) * mask + h = torch.rand(input.size(), device=input.device) * masks mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints v = torch.rand(nb, device=input.device)[:, None] mask_hints = mask_hints * (v < proba_hints).long() - mask = (1 - mask_hints) * mask + masks = (1 - mask_hints) * masks # noise = quiz_machine.problem.pure_noise(nb, input.device) targets = input - input = (1 - mask) * targets # + mask * noise + input = (1 - masks) * targets # + masks * noise - return input, targets, mask + return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -def predict(model, input, targets, mask, local_device=main_device): +def predict(model, imt_set, local_device=main_device): model.eval().to(local_device) - input_batches = input.reshape(-1, args.physical_batch_size, input.size(1)) - targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1)) - mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1)) - record = [] - for input, targets, mask in tqdm.tqdm( - zip(input_batches, targets_batches, mask_batches), + for imt in tqdm.tqdm( + imt_set.split(args.physical_batch_size), dynamic_ncols=True, desc="predict", - total=input.size(0) // args.physical_batch_size, + total=imt_set.size(0) // args.physical_batch_size, ): - # noise = quiz_machine.problem.pure_noise(input.size(0), input.device) - input = (1 - mask) * input # + mask * noise + masks = imt[:, 1] + imt = imt * (1 - masks[:, None]) # paranoia with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(NTC_channel_cat(input, mask)) + logits = model(imt[:, 0] * 2 + imt[:, 1]) dist = torch.distributions.categorical.Categorical(logits=logits) - result = (1 - mask) * input + mask * dist.sample() + result = (1 - masks) * imt[:, 0] + masks * dist.sample() record.append(result) return torch.cat(record) @@ -686,8 +687,10 @@ def predict(model, input, targets, mask, local_device=main_device): ###################################################################### +# IMT for input / masks / target + -def batch_generation(input): +def IMT_batch_generation(input): nb = input.size(0) probs_iterations = 0.1 ** torch.linspace( 0, 1, args.diffusion_nb_iterations, device=input.device @@ -704,10 +707,10 @@ def batch_generation(input): targets = input input = (1 - mask_erased) * input + mask_erased * noise - mask = input.new_full(input.size(), 1) - mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0 + masks = input.new_full(input.size(), 1) + masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0 - return input, targets, mask + return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) def prioritized_rand(low): @@ -721,18 +724,18 @@ def prioritized_rand(low): def generate(model, nb, local_device=main_device): input = quiz_machine.problem.pure_noise(nb, local_device) - mask = input.new_full(input.size(), 1) - mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0 + masks = input.new_full(input.size(), 1) + masks.reshape(masks.size(0), 4, -1)[:, :, 0] = 0 changed = True for it in range(args.diffusion_nb_iterations): with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(NTC_channel_cat(input, mask)) + logits = model(input) dist = torch.distributions.categorical.Categorical(logits=logits) output = dist.sample() r = prioritized_rand(input != output) - mask_changes = (r <= args.diffusion_proba_corruption).long() * mask + mask_changes = (r <= args.diffusion_proba_corruption).long() * masks update = (1 - mask_changes) * input + mask_changes * output if update.equal(input): break @@ -768,16 +771,14 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): ) input_p, input_g = quizzes.to(local_device).chunk(2) - input_p, targets_p, mask_p = batch_prediction(input_p, proba_hints=0.5) - input_g, targets_g, mask_g = batch_generation(input_g) + imt_set = torch.cat( + [IMT_batch_prediction(input_p, proba_hints=0.5), IMT_batch_generation(input_g)] + ) - perm = torch.randperm(quizzes.size(0), device=local_device) - input_batches = batch_interleave(input_p, input_g, perm) - targets_batches = batch_interleave(targets_p, targets_g, perm) - mask_batches = batch_interleave(mask_p, mask_g, perm) + imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)] - for input, targets, mask in tqdm.tqdm( - zip(input_batches, targets_batches, mask_batches), + for imt in tqdm.tqdm( + imt_set.split(args.physical_batch_size), dynamic_ncols=True, desc=label, total=quizzes.size(0) // args.physical_batch_size, @@ -786,11 +787,11 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): model.optimizer.zero_grad() with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(NTC_channel_cat(input, mask)) + logits = model(imt[:, 0] * 2 + imt[:, 1]) - loss = NTC_masked_cross_entropy(logits, targets, mask) - acc_loss += loss.item() * input.size(0) - nb_samples += input.size(0) + loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1]) + acc_loss += loss.item() * imt.size(0) + nb_samples += imt.size(0) if train: loss.backward() @@ -810,11 +811,11 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): # predict quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) - input, targets, mask = batch_prediction(quizzes.to(local_device)) - result = predict(model, input, targets, mask, local_device=local_device).to("cpu") - mask = mask.to("cpu") + imt_set = IMT_batch_prediction(quizzes.to(local_device)) + result = predict(model, imt_set, local_device=local_device).to("cpu") + masks = imt_set[:, 1].to("cpu") correct = (quizzes == result).min(dim=1).values.long() - correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[ + correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[ :, :, 1 ] predicted_parts = correct_parts.abs() @@ -845,9 +846,8 @@ import attae models = [] for i in range(args.nb_models): - # model = MyAttentionAE( - model = attae.MaskedAttentionAE( - vocabulary_size=vocabulary_size, + model = attae.AttentionAE( + vocabulary_size=vocabulary_size * 2, dim_model=args.dim_model, dim_keys=args.dim_keys, dim_hidden=args.dim_hidden, -- 2.39.5