From: François Fleuret Date: Wed, 24 Jul 2024 06:09:58 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1f93fc3b62588370707283c08561d2af9acb9c26;p=culture.git Update. --- diff --git a/main.py b/main.py index 58b4f18..122dd31 100755 --- a/main.py +++ b/main.py @@ -101,8 +101,6 @@ parser.add_argument("--temperature_cold", type=float, default=0.75) parser.add_argument("--nb_rounds", type=int, default=3) -parser.add_argument("--noise_level", type=float, default=0) - parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") parser.add_argument("--p2a_only", action="store_true", default=False) @@ -388,16 +386,6 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input - if args.noise_level > 0: - m = ( - (torch.rand(targets.size(), device=targets.device) < args.noise_level) - & (targets != quiz_machine.problem.token_forward) - & (targets != quiz_machine.problem.token_backward) - ).long() - input = (1 - m) * input.clone() + m * torch.randint( - vocabulary_size, input.size(), device=input.device - ) - output = model(mygpt.BracketedSequence(input)).x loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none"