Update
authorFrançois Fleuret <francois@fleuret.org>
Thu, 23 Mar 2023 08:43:14 +0000 (09:43 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 23 Mar 2023 08:43:14 +0000 (09:43 +0100)
beaver.py

index 69116ea..dca97cc 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -186,7 +186,7 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None)
 ######################################################################
 
 
-def compute_perplexity(model, split="train"):
+def compute_perplexity(model, fixed_len, split="train"):
     with torch.autograd.no_grad():
         t = model.training
         model.eval()
@@ -195,8 +195,9 @@ def compute_perplexity(model, split="train"):
 
         for input in task.batches(split=split):
             input = input.to(device)
-            input, order = shuffle(input, task.height * task.width)
-            output = model(mygpt.BracketedSequence(input), order=order).x
+            x, order = shuffle(input, fixed_len)
+            x = model(mygpt.BracketedSequence(x), order=order).x
+            output = reorder(x, order, back=True)
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_loss += loss.item() * input.size(0)
             nb_samples += input.size(0)
@@ -431,11 +432,11 @@ class TaskMaze(Task):
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
-            result, order = shuffle(result, self.height * self.width)
+            x, order = shuffle(result, self.height * self.width)
             masked_inplace_autoregression(
-                model, self.batch_size, result, ar_mask, order=order
+                model, self.batch_size, x, ar_mask, order=order
             )
-            result = reorder(result, order, back=True)
+            result = reorder(x, order, back=True)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
             nb_total += mazes.size(0)
@@ -582,8 +583,12 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 if nb_epochs_finished >= args.nb_epochs:
     n_epoch = nb_epochs_finished
-    train_perplexity = compute_perplexity(model, split="train")
-    test_perplexity = compute_perplexity(model, split="test")
+    train_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="train"
+    )
+    test_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="test"
+    )
 
     log_string(
         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
@@ -613,8 +618,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     for input in task.batches(split="train"):
         input = input.to(device)
-        input, order = shuffle(input, task.height * task.width)
-        output = model(mygpt.BracketedSequence(input), order=order).x
+        x, order = shuffle(input, task.height * task.width)
+        x = model(mygpt.BracketedSequence(x), order=order).x
+        output = reorder(x, order, back=True)
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
@@ -624,7 +630,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
         optimizer.step()
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-    test_perplexity = compute_perplexity(model, split="test")
+    test_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="test"
+    )
 
     log_string(
         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"