From 1f93fc3b62588370707283c08561d2af9acb9c26 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 24 Jul 2024 08:09:58 +0200 Subject: [PATCH] Update. --- main.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/main.py b/main.py index 58b4f18..122dd31 100755 --- a/main.py +++ b/main.py @@ -101,8 +101,6 @@ parser.add_argument("--temperature_cold", type=float, default=0.75) parser.add_argument("--nb_rounds", type=int, default=3) -parser.add_argument("--noise_level", type=float, default=0) - parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") parser.add_argument("--p2a_only", action="store_true", default=False) @@ -388,16 +386,6 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input - if args.noise_level > 0: - m = ( - (torch.rand(targets.size(), device=targets.device) < args.noise_level) - & (targets != quiz_machine.problem.token_forward) - & (targets != quiz_machine.problem.token_backward) - ).long() - input = (1 - m) * input.clone() + m * torch.randint( - vocabulary_size, input.size(), device=input.device - ) - output = model(mygpt.BracketedSequence(input)).x loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" -- 2.39.5