X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=49cb1f672075795f84021e9e8a5d6573a6d6f56e;hb=10c1e2159ef57a55724fb1753381dc30e8aa77c2;hp=6ec0fb290e2109077b6aefe1a2ae63d032e755b2;hpb=2cd3f15987d2bf9050f737cd13506740ad3e90cb;p=beaver.git diff --git a/beaver.py b/beaver.py index 6ec0fb2..49cb1f6 100755 --- a/beaver.py +++ b/beaver.py @@ -64,6 +64,8 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--random_regression_order", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -130,9 +132,12 @@ for n in vars(args): def random_order(result, fixed_len): - order = torch.rand(result.size(), device=result.device) - order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) - return order.sort(1).indices + if args.random_regression_order: + order = torch.rand(result.size(), device=result.device) + order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) + return order.sort(1).indices + else: + return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1) def shuffle(x, order, reorder=False): @@ -245,7 +250,7 @@ def oneshot(gpt, task): acc_train_loss, nb_train_samples = 0, 0 for mazes, policies in task.policy_batches(split="train"): - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -261,7 +266,7 @@ def oneshot(gpt, task): acc_test_loss, nb_test_samples = 0, 0 for mazes, policies in task.policy_batches(split="test"): - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -277,7 +282,7 @@ def oneshot(gpt, task): # ------------------- mazes = task.test_input[:32, : task.height * task.width] policies = task.test_policies[:32] - order = random_order(input, task.height * task.width) + order = random_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -579,8 +584,6 @@ if nb_epochs_finished >= args.nb_epochs: task.produce_results(n_epoch, model) - exit(0) - ############################## for n_epoch in range(nb_epochs_finished, args.nb_epochs):