From 06b5009eb6b196402304636da1d068968c29cda7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 15:49:23 +0200 Subject: [PATCH] Update. --- main.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index 6b137bf..77976a4 100755 --- a/main.py +++ b/main.py @@ -372,24 +372,18 @@ def masked_cross_entropy(output, targets, masks): ###################################################################### +# Make args.nb_hints holes in the mask and copy the corresponding cell +# values from the target to the input -def add_hints_(imt_set): - input, masks, targets = imt_set + +def add_hints(imt_set): + input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] h = torch.rand(masks.size(), device=masks.device) - masks t = h.sort(dim=1).values[:, args.nb_hints, None] mask_hints = (h < t).long() - masks[...] = (1 - mask_hints) * masks - input[...] = (1 - mask_hints) * input + mask_hints * targets - - -def add_hints(masks, fraction_with_hints): - if fraction_with_hints > 0: - h = torch.rand(masks.size(), device=masks.device) - masks - t = h.sort(dim=1).values[:, args.nb_hints, None] - mask_hints = (h < t).long() - return (1 - mask_hints) * masks - else: - return masks + masks = (1 - mask_hints) * masks + input = (1 - mask_hints) * input + mask_hints * targets + return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) # IMT for input / masks / target @@ -435,7 +429,7 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"): return torch.cat(record) -def predict_full(model, input, local_device=main_device): +def predict_full(model, input, with_hints=False, local_device=main_device): input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) nb = input.size(0) masks = input.new_zeros(input.size()) @@ -445,6 +439,9 @@ def predict_full(model, input, local_device=main_device): input = (1 - masks) * targets imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) + if with_hints: + imt_set = add_hints(imt_set) + result = ae_predict(model, imt_set, local_device=local_device, desc=None) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) @@ -533,15 +530,12 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): args.c_quiz_multiplier, ) - q1, q2 = quizzes.to(local_device).chunk(2) - - imt_set = torch.cat( - [ - batch_for_prediction_imt(q1), - batch_for_generation_imt(q2), - ] - ) - + q_p, q_g = quizzes.to(local_device).chunk(2) + b_p = batch_for_prediction_imt(q_p) + i = torch.rand(b_p.size(0)) < 0.5 + b_p[i] = add_hints(b_p[i]) + b_g = batch_for_generation_imt(q_g) + imt_set = torch.cat([b_p, b_g]) imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)] if train: @@ -679,7 +673,7 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): result = predict_full( model=model, input=quizzes, - fraction_with_hints=fraction_with_hints, + with_hints=True, local_device=local_device, ) nb_mistakes = (result != quizzes).long().sum(dim=1) -- 2.39.5