From 539b475100e792e284d030e2a0b4bdb41c0ff780 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 22 Mar 2023 23:10:53 +0100 Subject: [PATCH] Update --- beaver.py | 20 +++++++++++--------- mygpt.py | 18 +++++++----------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/beaver.py b/beaver.py index 49cb1f6..8fe9a9b 100755 --- a/beaver.py +++ b/beaver.py @@ -131,13 +131,15 @@ for n in vars(args): ###################################################################### -def random_order(result, fixed_len): +def generation_order(x, fixed_len): if args.random_regression_order: - order = torch.rand(result.size(), device=result.device) + order = torch.rand(x.size(), device=x.device) order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) return order.sort(1).indices else: - return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1) + return ( + torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1) + ) def shuffle(x, order, reorder=False): @@ -184,7 +186,7 @@ def compute_perplexity(model, split="train"): for input in task.batches(split=split): input = input.to(device) - order = random_order(input, task.height * task.width) + order = generation_order(input, task.height * task.width) input = shuffle(input, order) output = model(mygpt.BracketedSequence(input), order=order).x loss = F.cross_entropy(output.transpose(1, 2), input) @@ -250,7 +252,7 @@ def oneshot(gpt, task): acc_train_loss, nb_train_samples = 0, 0 for mazes, policies in task.policy_batches(split="train"): - order = random_order(mazes, task.height * task.width) + order = generation_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -266,7 +268,7 @@ def oneshot(gpt, task): acc_test_loss, nb_test_samples = 0, 0 for mazes, policies in task.policy_batches(split="test"): - order = random_order(mazes, task.height * task.width) + order = generation_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -282,7 +284,7 @@ def oneshot(gpt, task): # ------------------- mazes = task.test_input[:32, : task.height * task.width] policies = task.test_policies[:32] - order = random_order(mazes, task.height * task.width) + order = generation_order(mazes, task.height * task.width) x = shuffle(mazes, order) x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x output_gpt = shuffle(x, order, reorder=True) @@ -424,7 +426,7 @@ class TaskMaze(Task): ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 result *= 1 - ar_mask - order = random_order(result, self.height * self.width) + order = generation_order(result, self.height * self.width) masked_inplace_autoregression( model, self.batch_size, result, ar_mask, order=order ) @@ -606,7 +608,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): for input in task.batches(split="train"): input = input.to(device) - order = random_order(input, task.height * task.width) + order = generation_order(input, task.height * task.width) input = shuffle(input, order) output = model(mygpt.BracketedSequence(input), order=order).x loss = F.cross_entropy(output.transpose(1, 2), input) diff --git a/mygpt.py b/mygpt.py index 311ff6b..75adbf6 100755 --- a/mygpt.py +++ b/mygpt.py @@ -106,20 +106,16 @@ class AddPositionalEncoding(nn.Module): ) order_output = order + 1 - order_input = torch.cat( - (order.new_zeros(order.size(0), 1), order[:, :-1] + 1), 1 - ) + order_input = F.pad(order + 1, (1, -1)) - self.pe = torch.cat( - ( - pe.gather(1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1))), - pe.gather( - 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1)) - ), - ), - 2, + pe_input = pe.gather( + 1, order_input.unsqueeze(-1).expand(-1, -1, pe.size(-1)) + ) + pe_output = pe.gather( + 1, order_output.unsqueeze(-1).expand(-1, -1, pe.size(-1)) ) + self.pe = torch.cat((pe_input, pe_output), 2) self.cache_y = bs.x.new(bs.x.size()) self.cache_y[:, bs.first : bs.first + bs.nb] = ( -- 2.20.1