From 10c1e2159ef57a55724fb1753381dc30e8aa77c2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 22 Mar 2023 20:16:29 +0100 Subject: [PATCH] Update --- beaver.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/beaver.py b/beaver.py index 5a15aee..49cb1f6 100755 --- a/beaver.py +++ b/beaver.py @@ -64,6 +64,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--random_regression_order", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -130,9 +132,12 @@ for n in vars(args): def random_order(result, fixed_len): - order = torch.rand(result.size(), device=result.device) - order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) - return order.sort(1).indices + if args.random_regression_order: + order = torch.rand(result.size(), device=result.device) + order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) + return order.sort(1).indices + else: + return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1) def shuffle(x, order, reorder=False): @@ -579,8 +584,6 @@ if nb_epochs_finished >= args.nb_epochs: task.produce_results(n_epoch, model) - exit(0) - ############################## for n_epoch in range(nb_epochs_finished, args.nb_epochs): -- 2.39.5