parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+parser.add_argument("--random_regression_order", action="store_true", default=False)
+
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--overwrite_results", action="store_true", default=False)
def random_order(result, fixed_len):
- order = torch.rand(result.size(), device=result.device)
- order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
- return order.sort(1).indices
+ if args.random_regression_order:
+ order = torch.rand(result.size(), device=result.device)
+ order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device)
+ return order.sort(1).indices
+ else:
+ return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1)
def shuffle(x, order, reorder=False):
acc_train_loss, nb_train_samples = 0, 0
for mazes, policies in task.policy_batches(split="train"):
- order = random_order(input, task.height * task.width)
+ order = random_order(mazes, task.height * task.width)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
acc_test_loss, nb_test_samples = 0, 0
for mazes, policies in task.policy_batches(split="test"):
- order = random_order(input, task.height * task.width)
+ order = random_order(mazes, task.height * task.width)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
# -------------------
mazes = task.test_input[:32, : task.height * task.width]
policies = task.test_policies[:32]
- order = random_order(input, task.height * task.width)
+ order = random_order(mazes, task.height * task.width)
x = shuffle(mazes, order)
x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
output_gpt = shuffle(x, order, reorder=True)
task.produce_results(n_epoch, model)
- exit(0)
-
##############################
for n_epoch in range(nb_epochs_finished, args.nb_epochs):