- if args.autoencoder_weight > 0:
- bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
- output_ar, output_ae = bs_ar.x, bs_ae.x
- loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
- loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1])
- else:
- output = model(mygpt.BracketedSequence(input)).x
- loss_ar = F.cross_entropy(output.transpose(1, 2), input)
- loss_ae = loss_ar.new_full((1,), 0.0)
-
- acc_train_loss_ar += loss_ar.item() * input.size(0)
- acc_train_loss_ae += loss_ae.item() * input.size(0)
+ output = model(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_train_loss += loss.item() * input.size(0)