Update
[beaver.git] / beaver.py
index 6a6343d..f850f69 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -66,6 +66,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--random_regression_order", action="store_true", default=False)
 
+parser.add_argument("--noncausal_prompt", action="store_true", default=False)
+
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
 parser.add_argument("--overwrite_results", action="store_true", default=False)
@@ -203,7 +205,11 @@ def compute_perplexity(model, task, fixed_len, split="train"):
         for input in task.batches(split=split):
             input = input.to(device)
             output = eval_mygpt(model, input, fixed_len=fixed_len)
-            loss = F.cross_entropy(output.transpose(1, 2), input)
+            if args.noncausal_prompt:
+                t = input.size(1) // 2
+                loss = F.cross_entropy(output[:, t:].transpose(1, 2), input[:, t:])
+            else:
+                loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_loss += loss.item() * input.size(0)
             nb_samples += input.size(0)
 
@@ -517,6 +523,17 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
 
+amm_generator = None
+
+if args.noncausal_prompt:
+    amm_generator = lambda d: torch.logical_and(
+        torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :],
+        torch.logical_or(
+            torch.arange(d)[None, None, :, None] >= d // 2,
+            torch.arange(d)[None, None, None, :] >= d // 2,
+        ),
+    )
+
 model = mygpt.MyGPT(
     vocabulary_size=vocabulary_size,
     dim_model=args.dim_model,
@@ -526,6 +543,7 @@ model = mygpt.MyGPT(
     nb_blocks=args.nb_blocks,
     causal=True,
     dropout=args.dropout,
+    amm_generator=amm_generator,
 )
 
 model.to(device)
@@ -634,7 +652,11 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs):
         output = eval_mygpt(
             model, input, mode=args.oneshot_input, fixed_len=task.height * task.width
         )
-        loss = F.cross_entropy(output.transpose(1, 2), input)
+        if args.noncausal_prompt:
+            t = input.size(1) // 2
+            loss = F.cross_entropy(output[:, t:].transpose(1, 2), input[:, t:])
+        else:
+            loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)