Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 06:09:58 +0000 (08:09 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 24 Jul 2024 06:09:58 +0000 (08:09 +0200)
main.py

diff --git a/main.py b/main.py
index 58b4f18..122dd31 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -101,8 +101,6 @@ parser.add_argument("--temperature_cold", type=float, default=0.75)
 
 parser.add_argument("--nb_rounds", type=int, default=3)
 
-parser.add_argument("--noise_level", type=float, default=0)
-
 parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
 
 parser.add_argument("--p2a_only", action="store_true", default=False)
@@ -388,16 +386,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
         targets = input
 
-        if args.noise_level > 0:
-            m = (
-                (torch.rand(targets.size(), device=targets.device) < args.noise_level)
-                & (targets != quiz_machine.problem.token_forward)
-                & (targets != quiz_machine.problem.token_backward)
-            ).long()
-            input = (1 - m) * input.clone() + m * torch.randint(
-                vocabulary_size, input.size(), device=input.device
-            )
-
         output = model(mygpt.BracketedSequence(input)).x
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"