Update
authorFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 18:32:12 +0000 (19:32 +0100)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 12 Mar 2023 18:32:12 +0000 (19:32 +0100)
beaver.py
maze.py
mygpt.py
tensorstack.py

index 517f29a..4f694da 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -68,6 +68,8 @@ parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
 parser.add_argument("--overwrite_results", action="store_true", default=False)
 
+parser.add_argument("--one_shot", action="store_true", default=False)
+
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
 ##############################
@@ -125,7 +127,6 @@ for n in vars(args):
 
 
 def masked_inplace_autoregression(model, batch_size, input, ar_mask):
-
     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
@@ -145,6 +146,36 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask):
 ######################################################################
 
 
+def compute_perplexity(model, split="train"):
+    with torch.autograd.no_grad():
+        t = model.training
+        model.eval()
+
+        nb_samples, acc_loss = 0, 0.0
+
+        for input in task.batches(split=split):
+            input = input.to(device)
+
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+            acc_loss += loss.item() * input.size(0)
+            nb_samples += input.size(0)
+
+        model.train(t)
+
+        return math.exp(min(100, acc_loss / nb_samples))
+
+
+######################################################################
+
+
+def one_shot(gpt, task):
+    pass
+
+
+######################################################################
+
+
 class Task:
     def batches(self, split="train"):
         pass
@@ -373,13 +404,28 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 ##############################
 
-nb_samples_seen = 0
+if args.one_shot:
+    one_shot(model, task)
+    exit(0)
+
+##############################
 
 if nb_epochs_finished >= nb_epochs:
-    task.produce_results(nb_epochs_finished, model)
+    n_epoch = nb_epochs_finished
+    train_perplexity = compute_perplexity(model, split="train")
+    test_perplexity = compute_perplexity(model, split="test")
 
-for n_epoch in range(nb_epochs_finished, nb_epochs):
+    log_string(
+        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+    )
+
+    task.produce_results(n_epoch, model)
 
+    exit(0)
+
+##############################
+
+for n_epoch in range(nb_epochs_finished, nb_epochs):
     learning_rate = learning_rate_schedule[n_epoch]
 
     log_string(f"learning_rate {learning_rate}")
@@ -403,34 +449,19 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
-        nb_samples_seen += input.size(0)
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-    with torch.autograd.no_grad():
-
-        model.eval()
-
-        nb_test_samples, acc_test_loss = 0, 0.0
-
-        for input in task.batches(split="test"):
-            input = input.to(device)
-
-            output = model(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
+    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+    test_perplexity = compute_perplexity(model, split="test")
 
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
-        log_string(
-            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
-        )
+    log_string(
+        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+    )
 
-        task.produce_results(n_epoch, model)
+    task.produce_results(n_epoch, model)
 
     checkpoint = {
         "nb_epochs_finished": n_epoch + 1,
diff --git a/maze.py b/maze.py
index e377d2f..cfdede3 100755 (executable)
--- a/maze.py
+++ b/maze.py
@@ -241,7 +241,6 @@ def save_image(name, mazes, target_paths, predicted_paths=None, path_correct=Non
 ######################################################################
 
 if __name__ == "__main__":
-
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     mazes, paths = create_maze_data(8)
     mazes, paths = mazes.to(device), paths.to(device)
index df6eab6..a0f3dbf 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -197,7 +197,6 @@ class MyGPT(nn.Module):
         dropout=0.0,
         len_max=1e5,
     ):
-
         super().__init__()
 
         assert dim_model % nb_heads == 0
@@ -258,7 +257,6 @@ class MyGPT(nn.Module):
 ######################################################################
 
 if __name__ == "__main__":
-
     print("Basic check.")
 
     vocabulary_size = 10
index 3218be1..584c12d 100755 (executable)
@@ -45,7 +45,6 @@ sys.excepthook = exception_hook
 ######################################################################
 
 if __name__ == "__main__":
-
     import torch
 
     def dummy(a, b):