# Command-line flags (all boolean switches, off by default).
parser.add_argument("--random_regression_order", action="store_true", default=False)
parser.add_argument("--noncausal_prompt", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--overwrite_results", action="store_true", default=False)
##############################
# Optional attention-matrix-mask generator passed to MyGPT.
# None keeps the model's default (purely causal) masking.
amm_generator = None

if args.noncausal_prompt:
    def amm_generator(d):
        """Return a (1, 1, d, d) boolean mask for a length-d sequence.

        Entry (q, k) is True when k is strictly after q AND q lies in the
        second half of the sequence (q >= d // 2).  NOTE(review): whether
        True means "masked out" or "allowed" depends on how MyGPT consumes
        this tensor — confirm against mygpt.MyGPT.
        """
        pos = torch.arange(d)  # hoisted: the original lambda built this three times
        return torch.logical_and(
            pos[None, None, :, None] < pos[None, None, None, :],
            pos[None, None, :, None] >= d // 2,
        )
# Build the GPT model; amm_generator (possibly None) customizes the
# attention mask when --noncausal_prompt is set.
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
    amm_generator=amm_generator,
)
model.to(device)