X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;fp=main.py;h=37515b5b26ddcbd952357cbc715ef6bc13969744;hb=d95b9b72b0f098b5c955395905a0aff710f553a7;hp=dace5f2e2276ff4ba2472b90677b253dea58a46e;hpb=798d9526e726b644979cf1124e714f705fdd5966;p=picoclvr.git diff --git a/main.py b/main.py index dace5f2..37515b5 100755 --- a/main.py +++ b/main.py @@ -815,7 +815,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): model.train() - nb_train_samples, acc_train_loss_ar, acc_train_loss_ae = 0, 0.0, 0.0 + nb_train_samples, acc_train_loss = 0, 0.0 for input in task.batches(split="train"): input = input.to(device) @@ -823,22 +823,13 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): if nb_train_samples % args.batch_size == 0: optimizer.zero_grad() - if args.autoencoder_weight > 0: - bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True) - output_ar, output_ae = bs_ar.x, bs_ae.x - loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input) - loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1]) - else: - output = model(mygpt.BracketedSequence(input)).x - loss_ar = F.cross_entropy(output.transpose(1, 2), input) - loss_ae = loss_ar.new_full((1,), 0.0) - - acc_train_loss_ar += loss_ar.item() * input.size(0) - acc_train_loss_ae += loss_ae.item() * input.size(0) + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), input) + acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) - (loss_ar + args.autoencoder_weight * loss_ae).backward() + loss.backward() if nb_train_samples % args.batch_size == 0: optimizer.step() @@ -846,46 +837,28 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): with torch.autograd.no_grad(): model.eval() - nb_test_samples, acc_test_loss_ar, acc_test_loss_ae = 0, 0.0, 0.0 + nb_test_samples, acc_test_loss = 0, 0.0 nb_samples_accumulated = 0 for input in task.batches(split="test"): input = input.to(device) - if args.autoencoder_weight > 0: - bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True) - output_ar, output_ae = bs_ar.x, bs_ae.x - loss_ae = F.cross_entropy( - output_ae[:, 1:].transpose(1, 2), input[:, :-1] - ) - acc_test_loss_ae += loss_ae.item() * input.size(0) - else: - bs_ar = model(mygpt.BracketedSequence(input)) - output_ar = bs_ar.x + bs = model(mygpt.BracketedSequence(input)) + output_ar = bs.x - loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input) + loss = F.cross_entropy(output.transpose(1, 2), input) - acc_test_loss_ar += loss_ar.item() * input.size(0) + acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) - train_ar_perplexity = math.exp(min(100, acc_train_loss_ar / nb_train_samples)) - test_ar_perplexity = math.exp(min(100, acc_test_loss_ar / nb_test_samples)) + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) log_string( - f"perplexity_ar {n_epoch} train_set {train_set_perplexity} train_prediction {train_ar_perplexity} test_prediction {test_ar_perplexity}" + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" ) - if args.autoencoder_weight > 0: - train_ae_perplexity = math.exp( - min(100, acc_train_loss_ae / nb_train_samples) - ) - test_ae_perplexity = math.exp(min(100, acc_test_loss_ae / nb_test_samples)) - - log_string( - f"perplexity_ae {n_epoch} train_set {train_set_perplexity} train_prediction {train_ae_perplexity} test_prediction {test_ae_perplexity}" - ) - task.produce_results( n_epoch=n_epoch, model=model,