- output = model(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
+
+ if nb_train_samples % args.batch_size == 0:  # only when the sample count is a multiple of the batch size
+ optimizer.zero_grad()
+
+ if args.autoencoder_weight > 0:  # two-output path: the model is asked for both results in one call
+ bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
+ output_ar, output_ae = bs_ar.x, bs_ae.x
+ loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)  # dims 1 and 2 swapped before the loss; target is input itself
+ loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1])  # drops index 0 of dim 1 from the output and the last index of dim 1 from the target
+ else:  # single-output path: same computation as the removed lines above
+ output = model(mygpt.BracketedSequence(input)).x
+ loss_ar = F.cross_entropy(output.transpose(1, 2), input)
+ loss_ae = loss_ar.new_full((1,), 0.0)  # zero placeholder so loss_ae.item() below works in this branch too
+
+ acc_train_loss_ar += loss_ar.item() * input.size(0)  # weighted by input.size(0), presumably the batch size (TODO confirm)
+ acc_train_loss_ae += loss_ae.item() * input.size(0)  # stays unchanged (adds 0) when the autoencoder path is off
+