model.train()
- nb_train_samples, acc_train_loss_ar, acc_train_loss_ae = 0, 0.0, 0.0
+ nb_train_samples, acc_train_loss = 0, 0.0
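+ # running total of summed per-sample cross-entropy; converted to a perplexity at the end of the epoch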
for input in task.batches(split="train"):
    input = input.to(device)
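+     # zero gradients only at full-batch boundaries, so that gradients can
+     # accumulate if task.batches() yields chunks smaller than args.batch_size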
    if nb_train_samples % args.batch_size == 0:
        optimizer.zero_grad()
-     if args.autoencoder_weight > 0:
-         bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
-         output_ar, output_ae = bs_ar.x, bs_ae.x
-         loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
-         loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1])
-     else:
-         output = model(mygpt.BracketedSequence(input)).x
-         loss_ar = F.cross_entropy(output.transpose(1, 2), input)
-         loss_ae = loss_ar.new_full((1,), 0.0)
-
-     acc_train_loss_ar += loss_ar.item() * input.size(0)
-     acc_train_loss_ae += loss_ae.item() * input.size(0)
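+     # logits come out as (N, T, vocab); F.cross_entropy expects the class
+     # dimension second, hence the transpose to (N, vocab, T)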
+     output = model(mygpt.BracketedSequence(input)).x
+     loss = F.cross_entropy(output.transpose(1, 2), input)
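+     # the loss is a per-token average, so weight it by batch size to keep a
+     # per-sample total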
+     acc_train_loss += loss.item() * input.size(0)
    nb_train_samples += input.size(0)
-     (loss_ar + args.autoencoder_weight * loss_ae).backward()
+     loss.backward()
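+     # step once a full batch worth of gradients has accumulated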
    if nb_train_samples % args.batch_size == 0:
        optimizer.step()
with torch.autograd.no_grad():
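+     # eval mode and no_grad: no dropout, no autograd graph during evaluation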
    model.eval()
-     nb_test_samples, acc_test_loss_ar, acc_test_loss_ae = 0, 0.0, 0.0
+     nb_test_samples, acc_test_loss = 0, 0.0
    nb_samples_accumulated = 0
    for input in task.batches(split="test"):
        input = input.to(device)
-         if args.autoencoder_weight > 0:
-             bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
-             output_ar, output_ae = bs_ar.x, bs_ae.x
-             loss_ae = F.cross_entropy(
-                 output_ae[:, 1:].transpose(1, 2), input[:, :-1]
-             )
-             acc_test_loss_ae += loss_ae.item() * input.size(0)
-         else:
-             bs_ar = model(mygpt.BracketedSequence(input))
-             output_ar = bs_ar.x
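+         # same forward pass and loss as in training, but with no backward pass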
+         output = model(mygpt.BracketedSequence(input)).x
-         loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
+         loss = F.cross_entropy(output.transpose(1, 2), input)
-         acc_test_loss_ar += loss_ar.item() * input.size(0)
+         acc_test_loss += loss.item() * input.size(0)
        nb_test_samples += input.size(0)
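+     # perplexity = exp(mean per-token loss); the exponent is capped at 100
+     # since math.exp overflows for very large losses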
-     train_ar_perplexity = math.exp(min(100, acc_train_loss_ar / nb_train_samples))
-     test_ar_perplexity = math.exp(min(100, acc_test_loss_ar / nb_test_samples))
+     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+     test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
    log_string(
-         f"perplexity_ar {n_epoch} train_set {train_set_perplexity} train_prediction {train_ar_perplexity} test_prediction {test_ar_perplexity}"
+         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )
-     if args.autoencoder_weight > 0:
-         train_ae_perplexity = math.exp(
-             min(100, acc_train_loss_ae / nb_train_samples)
-         )
-         test_ae_perplexity = math.exp(min(100, acc_test_loss_ae / nb_test_samples))
-
-         log_string(
-             f"perplexity_ae {n_epoch} train_set {train_set_perplexity} train_prediction {train_ae_perplexity} test_prediction {test_ae_perplexity}"
-         )
-
    task.produce_results(
        n_epoch=n_epoch,
        model=model,