Update
[beaver.git] / beaver.py
index f5bd924..bd17365 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -64,6 +64,8 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
+parser.add_argument("--random_regression_order", action="store_true", default=False)
+
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
 parser.add_argument("--overwrite_results", action="store_true", default=False)
@@ -86,7 +88,7 @@ parser.add_argument("--oneshot", action="store_true", default=False)
 
 parser.add_argument("--oneshot_input", type=str, default="head")
 
-parser.add_argument("--oneshot_output", type=str, default="policy")
+parser.add_argument("--oneshot_output", type=str, default="trace")
 
 ######################################################################
 
@@ -129,18 +131,49 @@ for n in vars(args):
 ######################################################################
 
 
+def generation_order(x, fixed_len):
+    if args.random_regression_order:
+        order = torch.rand(x.size(), device=x.device)
+        order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+        order = order.sort(1).indices
+    else:
+        order = (
+            torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1)
+        )
+    return order
+
+
+def reorder(x, order, back=False):  # x is NxTxD1x...xDk, order is NxT'
+    u = x.reshape(x.size()[:2] + (-1,))
+    order = order.unsqueeze(-1).expand(-1, -1, u.size(-1))
+    if back:
+        v = u.new(u.size())
+        v.scatter_(1, order, u)
+    else:
+        v = u.gather(1, order)
+    v = v.reshape(v.size()[:2] + x.size()[2:])
+    return v
+
+
+def shuffle(x, fixed_len):
+    order = generation_order(x, fixed_len)
+    return reorder(x, order), order
+
+
+######################################################################
+
 # ar_mask is a Boolean matrix of same shape as input, with 1s on the
 # tokens that should be generated
 
 
-def masked_inplace_autoregression(model, batch_size, input, ar_mask):
+def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
     for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)):
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             # Needed to initialize the model's cache
-            model(mygpt.BracketedSequence(input, 0, i.min()))
+            model(mygpt.BracketedSequence(input, 0, i.min()), order=order)
         for s in range(i.min(), i.max() + 1):
-            output = model(mygpt.BracketedSequence(input, s, 1)).x
+            output = model(mygpt.BracketedSequence(input, s, 1), order=order).x
             logits = output[:, s]
             if args.deterministic_synthesis:
                 t_next = logits.argmax(1)
@@ -153,7 +186,7 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask):
 ######################################################################
 
 
-def compute_perplexity(model, split="train"):
+def compute_perplexity(model, fixed_len, split="train"):
     with torch.autograd.no_grad():
         t = model.training
         model.eval()
@@ -162,8 +195,9 @@ def compute_perplexity(model, split="train"):
 
         for input in task.batches(split=split):
             input = input.to(device)
-
-            output = model(mygpt.BracketedSequence(input)).x
+            x, order = shuffle(input, fixed_len)
+            x = model(mygpt.BracketedSequence(x), order=order).x
+            output = reorder(x, order, back=True)
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_loss += loss.item() * input.size(0)
             nb_samples += input.size(0)
@@ -227,7 +261,9 @@ def oneshot(gpt, task):
 
         acc_train_loss, nb_train_samples = 0, 0
         for mazes, policies in task.policy_batches(split="train"):
-            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+            x, order = shuffle(mazes, task.height * task.width)
+            x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
+            output_gpt = reorder(x, order, back=True)
             output = model(output_gpt)
 
             loss = compute_loss(mazes, output, policies, task.height, task.width)
@@ -240,7 +276,9 @@ def oneshot(gpt, task):
 
         acc_test_loss, nb_test_samples = 0, 0
         for mazes, policies in task.policy_batches(split="test"):
-            output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+            x, order = shuffle(mazes, task.height * task.width)
+            x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
+            output_gpt = reorder(x, order, back=True)
             output = model(output_gpt)
             loss = compute_loss(mazes, output, policies, task.height, task.width)
             acc_test_loss += loss.item() * mazes.size(0)
@@ -253,7 +291,9 @@ def oneshot(gpt, task):
         # -------------------
         mazes = task.test_input[:32, : task.height * task.width]
         policies = task.test_policies[:32]
-        output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x
+        x, order = shuffle(mazes, task.height * task.width)
+        x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
+        output_gpt = reorder(x, order, back=True)
         output = model(output_gpt)
         if args.oneshot_output == "policy":
             targets = policies.permute(0, 2, 1)
@@ -272,15 +312,17 @@ def oneshot(gpt, task):
         scores = scores.reshape(-1, task.height, task.width)
         mazes = mazes.reshape(-1, task.height, task.width)
         targets = targets.reshape(-1, task.height, task.width)
+        filename = (
+            f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
+        )
         maze.save_image(
-            os.path.join(
-                args.result_dir,
-                f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
-            ),
+            os.path.join(args.result_dir, filename),
             mazes=mazes,
             score_paths=scores,
             score_truth=targets,
         )
+        log_string(f"wrote {filename}")
+
         # -------------------
 
     gpt.train(t)
@@ -290,7 +332,7 @@ def oneshot(gpt, task):
 
 
 class Task:
-    def batches(self, split="train"):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         pass
 
     def vocabulary_size(self):
@@ -350,17 +392,19 @@ class TaskMaze(Task):
 
         self.nb_codes = self.train_input.max() + 1
 
-    def batches(self, split="train", nb_to_use=-1):
+    def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         if nb_to_use > 0:
             input = input[:nb_to_use]
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+            input.split(self.batch_size), dynamic_ncols=True, desc=desc
         ):
             yield batch
 
-    def policy_batches(self, split="train", nb_to_use=-1):
+    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
         policies = self.train_policies if split == "train" else self.test_policies
@@ -371,10 +415,12 @@ class TaskMaze(Task):
             input = input[:nb_to_use]
             policies = policies[:nb_to_use]
 
+        if desc is None:
+            desc = f"epoch-{split}"
         for batch in tqdm.tqdm(
             zip(input.split(self.batch_size), policies.split(self.batch_size)),
             dynamic_ncols=True,
-            desc=f"epoch-{split}",
+            desc=desc,
         ):
             yield batch
 
@@ -388,7 +434,11 @@ class TaskMaze(Task):
             ar_mask = result.new_zeros(result.size())
             ar_mask[:, self.height * self.width :] = 1
             result *= 1 - ar_mask
-            masked_inplace_autoregression(model, self.batch_size, result, ar_mask)
+            x, order = shuffle(result, self.height * self.width)
+            masked_inplace_autoregression(
+                model, self.batch_size, x, ar_mask, order=order
+            )
+            result = reorder(x, order, back=True)
             mazes, paths = self.seq2map(result)
             nb_correct += maze.path_correctness(mazes, paths).long().sum()
             nb_total += mazes.size(0)
@@ -423,13 +473,15 @@ class TaskMaze(Task):
 
             mazes, paths = self.seq2map(input)
             _, predicted_paths = self.seq2map(result)
+            filename = f"result_{n_epoch:04d}.png"
             maze.save_image(
-                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
+                os.path.join(args.result_dir, filename),
                 mazes=mazes,
                 target_paths=paths,
                 predicted_paths=predicted_paths,
                 path_correct=maze.path_correctness(mazes, predicted_paths),
             )
+            log_string(f"wrote {filename}")
 
             model.train(t)
 
@@ -535,8 +587,12 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}")
 
 if nb_epochs_finished >= args.nb_epochs:
     n_epoch = nb_epochs_finished
-    train_perplexity = compute_perplexity(model, split="train")
-    test_perplexity = compute_perplexity(model, split="test")
+    train_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="train"
+    )
+    test_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="test"
+    )
 
     log_string(
         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
@@ -544,8 +600,6 @@ if nb_epochs_finished >= args.nb_epochs:
 
     task.produce_results(n_epoch, model)
 
-    exit(0)
-
 ##############################
 
 for n_epoch in range(nb_epochs_finished, args.nb_epochs):
@@ -568,7 +622,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
 
     for input in task.batches(split="train"):
         input = input.to(device)
-        output = model(mygpt.BracketedSequence(input)).x
+        x, order = shuffle(input, task.height * task.width)
+        x = model(mygpt.BracketedSequence(x), order=order).x
+        output = reorder(x, order, back=True)
         loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
@@ -578,7 +634,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
         optimizer.step()
 
     train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-    test_perplexity = compute_perplexity(model, split="test")
+    test_perplexity = compute_perplexity(
+        model, fixed_len=task.height * task.width, split="test"
+    )
 
     log_string(
         f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"