parser.add_argument("--nb_rounds", type=int, default=3)
-parser.add_argument("--noise_level", type=float, default=0)
-
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
parser.add_argument("--p2a_only", action="store_true", default=False)
targets = input
- if args.noise_level > 0:
- m = (
- (torch.rand(targets.size(), device=targets.device) < args.noise_level)
- & (targets != quiz_machine.problem.token_forward)
- & (targets != quiz_machine.problem.token_backward)
- ).long()
- input = (1 - m) * input.clone() + m * torch.randint(
- vocabulary_size, input.size(), device=input.device
- )
-
output = model(mygpt.BracketedSequence(input)).x
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"