X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=attentiontoy1d.py;h=92d90cf79bb62ac618b1bd7590ee8f8fb498cba8;hb=437a0746551145f241b39d4a95ae28ecd1410a54;hp=d7f06fe0b587ba8f08dbfdda93ca58728a955f84;hpb=4d0e56bee81c535293367628dd73cbf993d0690a;p=pytorch.git diff --git a/attentiontoy1d.py b/attentiontoy1d.py index d7f06fe..92d90cf 100755 --- a/attentiontoy1d.py +++ b/attentiontoy1d.py @@ -31,8 +31,15 @@ parser.add_argument('--positional_encoding', help = 'Provide a positional encoding', action='store_true', default=False) +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed (default 0, < 0 is no seeding)') + args = parser.parse_args() +if args.seed >= 0: + torch.manual_seed(args.seed) + ###################################################################### label='' @@ -62,8 +69,6 @@ if torch.cuda.is_available(): else: device = torch.device('cpu') -torch.manual_seed(1) - ###################################################################### seq_height_min, seq_height_max = 1.0, 25.0