def one_epoch(model, quiz_machine, local_device=None):
    """Run one training epoch of `model` on `quiz_machine`'s train split.

    Creates a fresh Adam optimizer (lr from the module-level `args`),
    trains with gradient accumulation in groups of `args.batch_size`
    samples, logs the train perplexity, runs the test pass, and finally
    releases `model.TRAINING_LOCK`.

    Args:
        model: the GPT-style model to train; must expose `.parameters()`
            and be callable on a `mygpt.BracketedSequence`.
        quiz_machine: provides `batches(model, split="train")` yielding
            LongTensor batches of token sequences.
        local_device: device to train on; defaults to the module-level
            `device` global when None.

    NOTE(review): relies on module globals `device`, `args`, `mygpt`,
    `log_string`, `n_epoch` and `run_tests` — confirm they are defined
    at file scope. Assumes the caller acquired `model.TRAINING_LOCK`
    before invoking this function; it is released here unconditionally.
    Assumes the train split is non-empty (otherwise the perplexity
    division raises ZeroDivisionError) — TODO confirm with caller.
    """
    if local_device is None:
        local_device = device  # module-level default device

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    model.to(local_device).train()

    nb_train_samples, acc_train_loss = 0, 0.0

    # Renamed from `input` to avoid shadowing the builtin.
    for batch in quiz_machine.batches(model, split="train"):
        batch = batch.to(local_device)

        # Gradient accumulation: clear gradients at the start of each
        # group of args.batch_size samples.
        if nb_train_samples % args.batch_size == 0:
            optimizer.zero_grad()

        # The model predicts each token; cross-entropy compares the
        # (batch, vocab, seq) logits against the input tokens themselves.
        output = model(mygpt.BracketedSequence(batch)).x
        loss = F.cross_entropy(output.transpose(1, 2), batch)
        # Weight the mean loss by the batch size so the epoch average
        # below is a true per-sample average.
        acc_train_loss += loss.item() * batch.size(0)

        nb_train_samples += batch.size(0)

        loss.backward()

        # Apply the accumulated gradients once a full group is reached.
        if nb_train_samples % args.batch_size == 0:
            optimizer.step()

    # Bug fix: if the epoch ended mid-group, the last partial group's
    # gradients were computed but never applied — step once more so
    # they are not silently discarded.
    if nb_train_samples % args.batch_size != 0:
        optimizer.step()

    # Clamp the exponent at 100 to avoid overflow on a diverging loss.
    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))

    log_string(f"train_perplexity {n_epoch} {train_perplexity}")

    run_tests(model, quiz_machine, deterministic_synthesis=False)

    model.TRAINING_LOCK.release()