X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=f850f699ed9c3d18ef521df99b0a77595485fa7c;hb=e14c55948c1f099f95fa3d7343b5c939e60fcb1c;hp=bdc12aa405624a866de0f9f2eb320f5aedf53210;hpb=13e4bbafc59edd81528ddf8320b58052daec50b8;p=beaver.git diff --git a/beaver.py b/beaver.py index bdc12aa..f850f69 100755 --- a/beaver.py +++ b/beaver.py @@ -64,6 +64,10 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--random_regression_order", action="store_true", default=False) + +parser.add_argument("--noncausal_prompt", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) @@ -129,18 +133,56 @@ for n in vars(args): ###################################################################### +def generation_order(x, fixed_len=0): + if args.random_regression_order: + order = torch.rand(x.size(), device=x.device) + order[:, :fixed_len] = torch.arange(-fixed_len, 0, device=x.device) + order = order.sort(1).indices + else: + order = ( + torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(x.size(0), -1) + ) + return order + + +def reorder(x, order, reverse=False): # x is NxTxD1x...xDk, order is NxT' + u = x.reshape(x.size()[:2] + (-1,)) + order = order.unsqueeze(-1).expand(-1, -1, u.size(-1)) + if reverse: + v = u.new(u.size()).scatter_(1, order, u) + else: + v = u.gather(1, order) + v = v.reshape(v.size()[:2] + x.size()[2:]) + return v + + +def shuffle(x, fixed_len): + order = generation_order(x, fixed_len) + return reorder(x, order), order + + +def eval_mygpt(model, input, mode="standard", fixed_len=0): + x, order = shuffle(input, fixed_len) + x = model(mygpt.BracketedSequence(x), mode=mode, order=order).x + return reorder(x, order, reverse=True) + + +###################################################################### + # ar_mask is a Boolean matrix of same shape as input, with 1s on the # tokens that should be generated -def masked_inplace_autoregression(model, batch_size, input, ar_mask): - for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): +def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None): + for input, ar_mask, order in zip( + input.split(batch_size), ar_mask.split(batch_size), order.split(batch_size) + ): i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: # Needed to initialize the model's cache - model(mygpt.BracketedSequence(input, 0, i.min())) + model(mygpt.BracketedSequence(input, 0, i.min()), order=order) for s in range(i.min(), i.max() + 1): - output = model(mygpt.BracketedSequence(input, s, 1)).x + output = model(mygpt.BracketedSequence(input, s, 1), order=order).x logits = output[:, s] if args.deterministic_synthesis: t_next = logits.argmax(1) @@ -153,7 +195,7 @@ def masked_inplace_autoregression(model, batch_size, input, ar_mask): ###################################################################### -def compute_perplexity(model, split="train"): +def compute_perplexity(model, task, fixed_len, split="train"): with torch.autograd.no_grad(): t = model.training model.eval() @@ -162,9 +204,12 @@ def compute_perplexity(model, split="train"): for input in task.batches(split=split): input = input.to(device) - - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) + output = eval_mygpt(model, input, fixed_len=fixed_len) + if args.noncausal_prompt: + t = input.size(1) // 2 + loss = F.cross_entropy(output[:, t:].transpose(1, 2), input[:, t:]) + else: + loss = F.cross_entropy(output.transpose(1, 2), input) acc_loss += loss.item() * input.size(0) nb_samples += input.size(0) @@ -227,7 +272,9 @@ def oneshot(gpt, task): acc_train_loss, nb_train_samples = 0, 0 for mazes, policies in task.policy_batches(split="train"): - output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x + output_gpt = eval_mygpt( + gpt, mazes, mode=args.oneshot_input, fixed_len=task.height * task.width + ) output = model(output_gpt) loss = compute_loss(mazes, output, policies, task.height, task.width) @@ -240,7 +287,9 @@ def oneshot(gpt, task): acc_test_loss, nb_test_samples = 0, 0 for mazes, policies in task.policy_batches(split="test"): - output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x + output_gpt = eval_mygpt( + gpt, mazes, mode=args.oneshot_input, fixed_len=task.height * task.width + ) output = model(output_gpt) loss = compute_loss(mazes, output, policies, task.height, task.width) acc_test_loss += loss.item() * mazes.size(0) @@ -253,7 +302,9 @@ def oneshot(gpt, task): # ------------------- mazes = task.test_input[:32, : task.height * task.width] policies = task.test_policies[:32] - output_gpt = gpt(mygpt.BracketedSequence(mazes), mode=args.oneshot_input).x + output_gpt = eval_mygpt( + gpt, mazes, mode=args.oneshot_input, fixed_len=task.height * task.width + ) output = model(output_gpt) if args.oneshot_output == "policy": targets = policies.permute(0, 2, 1) @@ -272,15 +323,17 @@ def oneshot(gpt, task): scores = scores.reshape(-1, task.height, task.width) mazes = mazes.reshape(-1, task.height, task.width) targets = targets.reshape(-1, task.height, task.width) + filename = ( + f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png" + ) maze.save_image( - os.path.join( - args.result_dir, - f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png", - ), + os.path.join(args.result_dir, filename), mazes=mazes, score_paths=scores, score_truth=targets, ) + log_string(f"wrote {filename}") + # ------------------- gpt.train(t) @@ -290,7 +343,7 @@ def oneshot(gpt, task): class Task: - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): pass def vocabulary_size(self): @@ -350,17 +403,19 @@ class TaskMaze(Task): self.nb_codes = self.train_input.max() + 1 - def batches(self, split="train", nb_to_use=-1): + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input if nb_to_use > 0: input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield batch - def policy_batches(self, split="train", nb_to_use=-1): + def policy_batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input policies = self.train_policies if split == "train" else self.test_policies @@ -371,10 +426,12 @@ class TaskMaze(Task): input = input[:nb_to_use] policies = policies[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( zip(input.split(self.batch_size), policies.split(self.batch_size)), dynamic_ncols=True, - desc=f"epoch-{split}", + desc=desc, ): yield batch @@ -388,7 +445,11 @@ class TaskMaze(Task): ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 result *= 1 - ar_mask - masked_inplace_autoregression(model, self.batch_size, result, ar_mask) + x, order = shuffle(result, self.height * self.width) + masked_inplace_autoregression( + model, self.batch_size, x, ar_mask, order=order + ) + result = reorder(x, order, reverse=True) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() nb_total += mazes.size(0) @@ -419,17 +480,23 @@ class TaskMaze(Task): ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 result *= 1 - ar_mask - masked_inplace_autoregression(model, self.batch_size, result, ar_mask) + x, order = shuffle(result, self.height * self.width) + masked_inplace_autoregression( + model, self.batch_size, x, ar_mask, order=order + ) + result = reorder(x, order, reverse=True) mazes, paths = self.seq2map(input) _, predicted_paths = self.seq2map(result) + filename = f"result_{n_epoch:04d}.png" maze.save_image( - os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"), + os.path.join(args.result_dir, filename), mazes=mazes, target_paths=paths, predicted_paths=predicted_paths, path_correct=maze.path_correctness(mazes, predicted_paths), ) + log_string(f"wrote {filename}") model.train(t) @@ -456,6 +523,17 @@ log_string(f"vocabulary_size {vocabulary_size}") ############################## +amm_generator = None + +if args.noncausal_prompt: + amm_generator = lambda d: torch.logical_and( + torch.arange(d)[None, None, :, None] < torch.arange(d)[None, None, None, :], + torch.logical_or( + torch.arange(d)[None, None, :, None] >= d // 2, + torch.arange(d)[None, None, None, :] >= d // 2, + ), + ) + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -465,6 +543,7 @@ model = mygpt.MyGPT( nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, + amm_generator=amm_generator, ) model.to(device) @@ -535,8 +614,12 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}") if nb_epochs_finished >= args.nb_epochs: n_epoch = nb_epochs_finished - train_perplexity = compute_perplexity(model, split="train") - test_perplexity = compute_perplexity(model, split="test") + train_perplexity = compute_perplexity( + model, task, fixed_len=task.height * task.width, split="train" + ) + test_perplexity = compute_perplexity( + model, task, fixed_len=task.height * task.width, split="test" + ) log_string( f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" @@ -544,8 +627,6 @@ if nb_epochs_finished >= args.nb_epochs: task.produce_results(n_epoch, model) - exit(0) - ############################## for n_epoch in range(nb_epochs_finished, args.nb_epochs): @@ -568,8 +649,14 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): for input in task.batches(split="train"): input = input.to(device) - output = model(mygpt.BracketedSequence(input)).x - loss = F.cross_entropy(output.transpose(1, 2), input) + output = eval_mygpt( + model, input, mode=args.oneshot_input, fixed_len=task.height * task.width + ) + if args.noncausal_prompt: + t = input.size(1) // 2 + loss = F.cross_entropy(output[:, t:].transpose(1, 2), input[:, t:]) + else: + loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -578,7 +665,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): optimizer.step() train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - test_perplexity = compute_perplexity(model, split="test") + test_perplexity = compute_perplexity( + model, task, fixed_len=task.height * task.width, split="test" + ) log_string( f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"