From aa56600deac5de543974886b5a34b634f194f92a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 16:05:48 +0200 Subject: [PATCH] Update. --- main.py | 50 +++++++++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index 77976a4..d8dffe2 100755 --- a/main.py +++ b/main.py @@ -351,26 +351,6 @@ def optimizer_to(optim, device): ###################################################################### -# quad_order, quad_generate, quad_noise, quad_loss - -data_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), - (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)), - (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)), - (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)), - (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), -] - - -###################################################################### - - -def masked_cross_entropy(output, targets, masks): - loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") - return (loss_per_token * masks).mean() - - -###################################################################### # Make args.nb_hints holes in the mask and copy the corresponding cell # values from the target to the input @@ -386,6 +366,20 @@ def add_hints(imt_set): return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) +# Make pixels from the available input (mask=0) noise with probability +# args.prompt_noise + + +def add_noise(imt_set): + input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] + noise = quiz_machine.pure_noise(input.size(0), input.device) + change = (1 - masks) * ( + torch.rand(input.size(), device=input.device) < args.prompt_noise + ).long() + input = (1 - change) * input + change * noise + return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) + + # IMT for input / masks / target @@ -429,7 +423,7 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"): return torch.cat(record) -def predict_full(model, input, with_hints=False, local_device=main_device): +def predict_full(model, input, with_perturbations=False, local_device=main_device): input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) nb = input.size(0) masks = input.new_zeros(input.size()) @@ -439,8 +433,9 @@ def predict_full(model, input, with_hints=False, local_device=main_device): input = (1 - masks) * targets imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) - if with_hints: + if with_perturbations: imt_set = add_hints(imt_set) + imt_set = add_noise(imt_set) result = ae_predict(model, imt_set, local_device=local_device, desc=None) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) @@ -533,6 +528,7 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): q_p, q_g = quizzes.to(local_device).chunk(2) b_p = batch_for_prediction_imt(q_p) i = torch.rand(b_p.size(0)) < 0.5 + b_p = add_noise(b_p) b_p[i] = add_hints(b_p[i]) b_g = batch_for_generation_imt(q_g) imt_set = torch.cat([b_p, b_g]) @@ -554,13 +550,17 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): desc=label, total=quizzes.size(0) // args.physical_batch_size, ): + input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2] if train and nb_samples % args.batch_size == 0: model.optimizer.zero_grad() with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): - logits = model(imt[:, 0] * 2 + imt[:, 1]) + logits = model(input * 2 + masks) - loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1]) + loss_per_token = F.cross_entropy( + logits.transpose(1, 2), targets, reduction="none" + ) + loss = (loss_per_token * masks).mean() acc_loss += loss.item() * imt.size(0) nb_samples += imt.size(0) @@ -673,7 +673,7 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): result = predict_full( model=model, input=quizzes, - with_hints=True, + with_perturbations=True, local_device=local_device, ) nb_mistakes = (result != quizzes).long().sum(dim=1) -- 2.39.5