X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=5ee468e0bf070a79ce1deff3045ab77b43e1becd;hb=9fb6b9a0f1aff25ec1b559c318102060178f3a90;hp=6a6343d4fc3e303f3352bfaaf7a29a5eabf020fc;hpb=88bcf05864ddf89d071ee4be17af57b3b3ce7c2a;p=beaver.git diff --git a/beaver.py b/beaver.py index 6a6343d..5ee468e 100755 --- a/beaver.py +++ b/beaver.py @@ -66,6 +66,8 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--random_regression_order", action="store_true", default=False) +parser.add_argument("--noncausal_prompt", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -517,6 +519,17 @@ log_string(f"vocabulary_size {vocabulary_size}") ############################## +amm_generator = None + +if args.noncausal_prompt: + amm_generator = lambda d: torch.logical_and( + torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :], + torch.logical_or( + torch.arange(d)[None, None, :, None] >= d // 2, + torch.arange(d)[None, None, None, :] >= d // 2, + ), + ) + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -526,6 +539,7 @@ model = mygpt.MyGPT( nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, + amm_generator=amm_generator, ) model.to(device)