From 716c1ecbd57e85b60d8ed02ee9d583e0a925e5e3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 22 Aug 2024 22:23:32 +0200 Subject: [PATCH] Update. --- main.py | 58 +++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index b6c62cf..f28c10a 100755 --- a/main.py +++ b/main.py @@ -871,7 +871,12 @@ def test_ae(local_device=main_device): model.optimizer.zero_grad() targets = input - input = (mask_generate == 0).long() * input + + input = (mask_generate == 0).long() * input + ( + 1 - (mask_generate == 0).long() + ) * torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) @@ -915,7 +920,13 @@ def test_ae(local_device=main_device): mask_loss = mask_loss.to(local_device) targets = input - input = (mask_generate == 0).long() * input + + input = (mask_generate == 0).long() * input + ( + 1 - (mask_generate == 0).long() + ) * torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) + output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) acc_test_loss += loss.item() * input.size(0) @@ -928,15 +939,38 @@ def test_ae(local_device=main_device): mask_generate = mask_generate.to(local_device) mask_loss = mask_loss.to(local_device) targets = input - input = (mask_generate == 0).long() * input - logits = model(mygpt.BracketedSequence(input)).x - dist = torch.distributions.categorical.Categorical(logits=logits) - result = dist.sample() + + pred_result = None + frozzen = None + + result = (mask_generate == 0).long() * input + ( + 1 - (mask_generate == 0).long() + ) * torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) + + i = torch.full((result.size(0),), True, device=result.device) + + nb_it = 0 + L = input.size(1) // 4 - result[:, 0 * L] = input[:, 0 * L] - result[:, 1 * L] = input[:, 1 * L] - result[:, 2 * L] = input[:, 2 * L] - result[:, 3 * L] = input[:, 3 * L] + + while True: + logits = model(mygpt.BracketedSequence(result)).x + dist = torch.distributions.categorical.Categorical(logits=logits) + pred_result = result.clone() + result[i] = dist.sample()[i] + result[:, 0 * L] = input[:, 0 * L] + result[:, 1 * L] = input[:, 1 * L] + result[:, 2 * L] = input[:, 2 * L] + result[:, 3 * L] = input[:, 3 * L] + changed = (pred_result == result).long().min(dim=1).values == 0 + i = i & changed + nb_it += 1 + print("DEBUG", nb_it, i.long().sum().item()) + if not i.any() or nb_it > 100: + break + correct = (result == targets).min(dim=1).values.long() predicted_parts = input.new(input.size(0), 4) @@ -958,8 +992,8 @@ def test_ae(local_device=main_device): nb_correct = (correct == 1).long().sum() nb_total = (correct != 0).long().sum() - self.logger( - f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" + log_string( + f"test_accuracy {n_epoch} model AE {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" ) correct_parts = predicted_parts * correct[:, None] -- 2.39.5