Update.
author    François Fleuret <francois@fleuret.org>
          Thu, 13 Jun 2024 17:42:51 +0000 (19:42 +0200)
committer François Fleuret <francois@fleuret.org>
          Thu, 13 Jun 2024 17:42:51 +0000 (19:42 +0200)
main.py

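Note: F.cross_entropy expects the class (vocabulary) logits in dimension 1, which is why both the train and test loops below apply transpose(1, 2) to the model's (batch, sequence, vocab) output before computing the loss. A minimal self-contained sketch of that shape convention, with illustrative sizes that are not taken from the script:

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab = 2, 5, 10                  # illustrative sizes
    output = torch.randn(batch, seq_len, vocab)       # logits as (N, T, C)
    targets = torch.randint(vocab, (batch, seq_len))  # token ids as (N, T)

    # cross_entropy wants (N, C, T) logits against (N, T) integer targets
    loss = F.cross_entropy(output.transpose(1, 2), targets)
    print(loss.item())
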
diff --git a/main.py b/main.py
index dace5f2..37515b5 100755
--- a/main.py
+++ b/main.py
@@ -815,7 +815,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     model.train()
 
-    nb_train_samples, acc_train_loss_ar, acc_train_loss_ae = 0, 0.0, 0.0
+    nb_train_samples, acc_train_loss = 0, 0.0
 
     for input in task.batches(split="train"):
         input = input.to(device)
@@ -823,22 +823,13 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
         if nb_train_samples % args.batch_size == 0:
             optimizer.zero_grad()
 
-        if args.autoencoder_weight > 0:
-            bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
-            output_ar, output_ae = bs_ar.x, bs_ae.x
-            loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
-            loss_ae = F.cross_entropy(output_ae[:, 1:].transpose(1, 2), input[:, :-1])
-        else:
-            output = model(mygpt.BracketedSequence(input)).x
-            loss_ar = F.cross_entropy(output.transpose(1, 2), input)
-            loss_ae = loss_ar.new_full((1,), 0.0)
-
-        acc_train_loss_ar += loss_ar.item() * input.size(0)
-        acc_train_loss_ae += loss_ae.item() * input.size(0)
+        output = model(mygpt.BracketedSequence(input)).x
+        loss = F.cross_entropy(output.transpose(1, 2), input)
+        acc_train_loss += loss.item() * input.size(0)
 
         nb_train_samples += input.size(0)
 
-        (loss_ar + args.autoencoder_weight * loss_ae).backward()
+        loss.backward()
 
         if nb_train_samples % args.batch_size == 0:
             optimizer.step()
@@ -846,46 +837,28 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
     with torch.autograd.no_grad():
         model.eval()
 
-        nb_test_samples, acc_test_loss_ar, acc_test_loss_ae = 0, 0.0, 0.0
+        nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            if args.autoencoder_weight > 0:
-                bs_ar, bs_ae = model(mygpt.BracketedSequence(input), autoencoder=True)
-                output_ar, output_ae = bs_ar.x, bs_ae.x
-                loss_ae = F.cross_entropy(
-                    output_ae[:, 1:].transpose(1, 2), input[:, :-1]
-                )
-                acc_test_loss_ae += loss_ae.item() * input.size(0)
-            else:
-                bs_ar = model(mygpt.BracketedSequence(input))
-                output_ar = bs_ar.x
+            bs = model(mygpt.BracketedSequence(input))
+            output = bs.x
 
-            loss_ar = F.cross_entropy(output_ar.transpose(1, 2), input)
+            loss = F.cross_entropy(output.transpose(1, 2), input)
 
-            acc_test_loss_ar += loss_ar.item() * input.size(0)
+            acc_test_loss += loss.item() * input.size(0)
 
             nb_test_samples += input.size(0)
 
-        train_ar_perplexity = math.exp(min(100, acc_train_loss_ar / nb_train_samples))
-        test_ar_perplexity = math.exp(min(100, acc_test_loss_ar / nb_test_samples))
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
 
         log_string(
-            f"perplexity_ar {n_epoch} train_set {train_set_perplexity} train_prediction {train_ar_perplexity} test_prediction {test_ar_perplexity}"
+            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
         )
 
-        if args.autoencoder_weight > 0:
-            train_ae_perplexity = math.exp(
-                min(100, acc_train_loss_ae / nb_train_samples)
-            )
-            test_ae_perplexity = math.exp(min(100, acc_test_loss_ae / nb_test_samples))
-
-            log_string(
-                f"perplexity_ae {n_epoch} train_set {train_set_perplexity} train_prediction {train_ae_perplexity} test_prediction {test_ae_perplexity}"
-            )
-
         task.produce_results(
             n_epoch=n_epoch,
             model=model,