X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=beaver.py;h=49cb1f672075795f84021e9e8a5d6573a6d6f56e;hb=10c1e2159ef57a55724fb1753381dc30e8aa77c2;hp=afec61d4a506161e0da2e449d2dfa3445e386110;hpb=71a5d04a1decec9d71be93cb816a15a8c0de83a2;p=beaver.git diff --git a/beaver.py b/beaver.py index afec61d..49cb1f6 100755 --- a/beaver.py +++ b/beaver.py @@ -64,12 +64,12 @@ parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--deterministic_synthesis", action="store_true", default=False) +parser.add_argument("--random_regression_order", action="store_true", default=False) + parser.add_argument("--no_checkpoint", action="store_true", default=False) parser.add_argument("--overwrite_results", action="store_true", default=False) -parser.add_argument("--one_shot", action="store_true", default=False) - parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## @@ -81,6 +81,15 @@ parser.add_argument("--maze_width", type=int, default=21) parser.add_argument("--maze_nb_walls", type=int, default=15) +############################## +# one-shot prediction + +parser.add_argument("--oneshot", action="store_true", default=False) + +parser.add_argument("--oneshot_input", type=str, default="head") + +parser.add_argument("--oneshot_output", type=str, default="trace") + ###################################################################### args = parser.parse_args() @@ -122,18 +131,38 @@ for n in vars(args): ###################################################################### +def random_order(result, fixed_len): + if args.random_regression_order: + order = torch.rand(result.size(), device=result.device) + order[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=order.device) + return order.sort(1).indices + else: + return torch.arange(result.size(1)).unsqueeze(0).expand(result.size(0), -1) + + +def shuffle(x, order, reorder=False): + if x.dim() == 3: + order = order.unsqueeze(-1).expand(-1, -1, x.size(-1)) + if reorder: + y = x.new(x.size()) + y.scatter_(1, order, x) + return y + else: + return x.gather(1, order) + + # ar_mask is a Boolean matrix of same shape as input, with 1s on the # tokens that should be generated -def masked_inplace_autoregression(model, batch_size, input, ar_mask): +def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None): for input, ar_mask in zip(input.split(batch_size), ar_mask.split(batch_size)): i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: # Needed to initialize the model's cache - model(mygpt.BracketedSequence(input, 0, i.min())) + model(mygpt.BracketedSequence(input, 0, i.min()), order=order) for s in range(i.min(), i.max() + 1): - output = model(mygpt.BracketedSequence(input, s, 1)).x + output = model(mygpt.BracketedSequence(input, s, 1), order=order).x logits = output[:, s] if args.deterministic_synthesis: t_next = logits.argmax(1) @@ -155,8 +184,9 @@ def compute_perplexity(model, split="train"): for input in task.batches(split=split): input = input.to(device) - - output = model(mygpt.BracketedSequence(input)).x + order = random_order(input, task.height * task.width) + input = shuffle(input, order) + output = model(mygpt.BracketedSequence(input), order=order).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_loss += loss.item() * input.size(0) nb_samples += input.size(0) @@ -169,16 +199,49 @@ def compute_perplexity(model, split="train"): ###################################################################### -def one_shot(gpt, task): +def oneshot_policy_loss(mazes, output, policies, height, width): + masks = (mazes == maze.v_empty).unsqueeze(-1) + targets = policies.permute(0, 2, 1) * masks + output = output * masks + return -(output.log_softmax(-1) * targets).sum() / masks.sum() + + +def oneshot_trace_loss(mazes, output, policies, height, width): + masks = mazes == maze.v_empty + targets = maze.stationary_densities( + mazes.view(-1, height, width), policies.view(-1, 4, height, width) + ).flatten(-2) + targets = targets * masks + output = output.squeeze(-1) * masks + return (output - targets).abs().sum() / masks.sum() + + +def oneshot(gpt, task): t = gpt.training gpt.eval() + if args.oneshot_input == "head": + dim_in = args.dim_model + elif args.oneshot_input == "deep": + dim_in = args.dim_model * args.nb_blocks * 2 + else: + raise ValueError(f"{args.oneshot_input=}") + + if args.oneshot_output == "policy": + dim_out = 4 + compute_loss = oneshot_policy_loss + elif args.oneshot_output == "trace": + dim_out = 1 + compute_loss = oneshot_trace_loss + else: + raise ValueError(f"{args.oneshot_output=}") + model = nn.Sequential( - nn.Linear(args.dim_model, args.dim_model), + nn.Linear(dim_in, args.dim_model), nn.ReLU(), nn.Linear(args.dim_model, args.dim_model), nn.ReLU(), - nn.Linear(args.dim_model, 4), + nn.Linear(args.dim_model, dim_out), ).to(device) for n_epoch in range(args.nb_epochs): @@ -186,60 +249,69 @@ def one_shot(gpt, task): optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) acc_train_loss, nb_train_samples = 0, 0 - for input, targets in task.policy_batches(split="train"): - output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x + for mazes, policies in task.policy_batches(split="train"): + order = random_order(mazes, task.height * task.width) + x = shuffle(mazes, order) + x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x + output_gpt = shuffle(x, order, reorder=True) output = model(output_gpt) - targets = targets * (input.unsqueeze(-1) == maze.v_empty) - output = output * (input.unsqueeze(-1) == maze.v_empty) - # loss = (output.softmax(-1) - targets).abs().max(-1).values.mean() - loss = ( - -(output.log_softmax(-1) * targets).sum() - / (input == maze.v_empty).sum() - ) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) + + loss = compute_loss(mazes, output, policies, task.height, task.width) + acc_train_loss += loss.item() * mazes.size(0) + nb_train_samples += mazes.size(0) optimizer.zero_grad() loss.backward() optimizer.step() acc_test_loss, nb_test_samples = 0, 0 - for input, targets in task.policy_batches(split="test"): - output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x + for mazes, policies in task.policy_batches(split="test"): + order = random_order(mazes, task.height * task.width) + x = shuffle(mazes, order) + x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x + output_gpt = shuffle(x, order, reorder=True) output = model(output_gpt) - targets = targets * (input.unsqueeze(-1) == maze.v_empty) - output = output * (input.unsqueeze(-1) == maze.v_empty) - # loss = (output.softmax(-1) - targets).abs().max(-1).values.mean() - loss = ( - -(output.log_softmax(-1) * targets).sum() - / (input == maze.v_empty).sum() - ) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) + loss = compute_loss(mazes, output, policies, task.height, task.width) + acc_test_loss += loss.item() * mazes.size(0) + nb_test_samples += mazes.size(0) log_string( f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}" ) # ------------------- - input = task.test_input[:32, : task.height * task.width] - targets = task.test_policies[:32] - output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x + mazes = task.test_input[:32, : task.height * task.width] + policies = task.test_policies[:32] + order = random_order(mazes, task.height * task.width) + x = shuffle(mazes, order) + x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x + output_gpt = shuffle(x, order, reorder=True) output = model(output_gpt) - # losses = (-output.log_softmax(-1) * targets + targets.xlogy(targets)).sum(-1) - # losses = losses * (input == maze.v_empty) - # losses = losses / losses.max() - # losses = (output.softmax(-1) - targets).abs().max(-1).values - # losses = (losses >= 0.05).float() - losses = ( - (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0 - ).float() - losses = losses.reshape(-1, args.maze_height, args.maze_width) - input = input.reshape(-1, args.maze_height, args.maze_width) + if args.oneshot_output == "policy": + targets = policies.permute(0, 2, 1) + scores = ( + (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0 + ).float() + elif args.oneshot_output == "trace": + targets = maze.stationary_densities( + mazes.view(-1, task.height, task.width), + policies.view(-1, 4, task.height, task.width), + ).flatten(-2) + scores = output + else: + raise ValueError(f"{args.oneshot_output=}") + + scores = scores.reshape(-1, task.height, task.width) + mazes = mazes.reshape(-1, task.height, task.width) + targets = targets.reshape(-1, task.height, task.width) maze.save_image( - os.path.join(args.result_dir, f"oneshot_{n_epoch:04d}.png"), - mazes=input, - score_paths=losses, + os.path.join( + args.result_dir, + f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png", + ), + mazes=mazes, + score_paths=scores, + score_truth=targets, ) # ------------------- @@ -250,7 +322,7 @@ def one_shot(gpt, task): class Task: - def batches(self, split="train"): + def batches(self, split="train", nb_to_use=-1, desc=None): pass def vocabulary_size(self): @@ -296,7 +368,7 @@ class TaskMaze(Task): progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), ) self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) - self.train_policies = train_policies.flatten(-2).permute(0, 2, 1).to(device) + self.train_policies = train_policies.flatten(-2).to(device) test_mazes, test_paths, test_policies = maze.create_maze_data( nb_test_samples, @@ -306,35 +378,39 @@ class TaskMaze(Task): progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), ) self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) - self.test_policies = test_policies.flatten(-2).permute(0, 2, 1).to(device) + self.test_policies = test_policies.flatten(-2).to(device) self.nb_codes = self.train_input.max() + 1 - def batches(self, split="train", nb_to_use=-1): + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input if nb_to_use > 0: input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}" + input.split(self.batch_size), dynamic_ncols=True, desc=desc ): yield batch - def policy_batches(self, split="train", nb_to_use=-1): + def policy_batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input - targets = self.train_policies if split == "train" else self.test_policies + policies = self.train_policies if split == "train" else self.test_policies input = input[:, : self.height * self.width] - targets = targets * (input != maze.v_wall)[:, :, None] + policies = policies * (input != maze.v_wall)[:, None] if nb_to_use > 0: input = input[:nb_to_use] - targets = targets[:nb_to_use] + policies = policies[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" for batch in tqdm.tqdm( - zip(input.split(self.batch_size), targets.split(self.batch_size)), + zip(input.split(self.batch_size), policies.split(self.batch_size)), dynamic_ncols=True, - desc=f"epoch-{split}", + desc=desc, ): yield batch @@ -348,7 +424,11 @@ class TaskMaze(Task): ar_mask = result.new_zeros(result.size()) ar_mask[:, self.height * self.width :] = 1 result *= 1 - ar_mask - masked_inplace_autoregression(model, self.batch_size, result, ar_mask) + order = random_order(result, self.height * self.width) + masked_inplace_autoregression( + model, self.batch_size, result, ar_mask, order=order + ) + result = shuffle(result, order, reorder=True) mazes, paths = self.seq2map(result) nb_correct += maze.path_correctness(mazes, paths).long().sum() nb_total += mazes.size(0) @@ -493,12 +573,6 @@ log_string(f"learning_rate_schedule {learning_rate_schedule}") ############################## -if args.one_shot: - one_shot(model, task) - exit(0) - -############################## - if nb_epochs_finished >= args.nb_epochs: n_epoch = nb_epochs_finished train_perplexity = compute_perplexity(model, split="train") @@ -510,8 +584,6 @@ if nb_epochs_finished >= args.nb_epochs: task.produce_results(n_epoch, model) - exit(0) - ############################## for n_epoch in range(nb_epochs_finished, args.nb_epochs): @@ -526,7 +598,7 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): elif args.optim == "adamw": optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) else: - raise ValueError(f"Unknown optimizer {args.optim}.") + raise ValueError(f"{args.optim=}") model.train() @@ -534,7 +606,9 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): for input in task.batches(split="train"): input = input.to(device) - output = model(mygpt.BracketedSequence(input)).x + order = random_order(input, task.height * task.width) + input = shuffle(input, order) + output = model(mygpt.BracketedSequence(input), order=order).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -566,3 +640,8 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): log_string(f"saved checkpoint {checkpoint_name}") ###################################################################### + +if args.oneshot: + oneshot(model, task) + +######################################################################