From: François Fleuret Date: Wed, 22 Mar 2023 19:16:29 +0000 (+0100) Subject: Update X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=10c1e2159ef57a55724fb1753381dc30e8aa77c2;p=beaver.git Update --- diff --git a/beaver.py b/beaver.py index 5a15aee..49cb1f6 100755 --- a/beaver.py +++ b/beaver.py @@ -64,6 +64,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--random_regression_order", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -130,9 +132,12 @@ for n in vars(args): def random_order(result, fixed_len): - order = torch.rand(result.size(), device=result.device) - order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) - return order.sort(1).indices + if args.random_regression_order: + order = torch.rand(result.size(), device=result.device) + order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) + return order.sort(1).indices + else: + return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1) def shuffle(x, order, reorder=False): @@ -579,8 +584,6 @@ if nb_epochs_finished >= args.nb_epochs: task.produce_results(n_epoch, model) - exit(0) - ############################## for n_epoch in range(nb_epochs_finished, args.nb_epochs):