+def compute_perplexity(model, task, prompt_len, split="train"):
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ nb_samples, acc_loss = 0, 0.0
+
+ for input in task.batches(split=split):
+ input = input.to(device)
+ output = eval_mygpt(model, input, prompt_len=prompt_len)
+ if args.noncausal_prompt:
+ d = input.size(1) // 2
+ loss = F.cross_entropy(output[:, d:].transpose(1, 2), input[:, d:])
+ else:
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_loss += loss.item() * input.size(0)
+ nb_samples += input.size(0)
+
+ model.train(t)
+
+ return math.exp(min(100, acc_loss / nb_samples))
+
+
+######################################################################
+
+
+def oneshot_policy_loss(mazes, output, policies, height, width):
+ masks = (mazes == maze.v_empty).unsqueeze(-1)
+ targets = policies.permute(0, 2, 1) * masks
+ output = output * masks
+ return -(output.log_softmax(-1) * targets).sum() / masks.sum()
+
+
+def oneshot_trace_loss(mazes, output, policies, height, width):
+ masks = mazes == maze.v_empty
+ targets = maze.stationary_densities(
+ mazes.view(-1, height, width), policies.view(-1, 4, height, width)
+ ).flatten(-2)
+ targets = targets * masks
+ output = output.squeeze(-1) * masks
+ return (output - targets).abs().sum() / masks.sum()
+
+
+def oneshot(model, learning_rate_scheduler, task):
+ t = model.training
+ model.eval()
+ mazes = task.test_input[:32].clone()
+ mazes[:, task.height * task.width :] = 0
+ policies = task.test_policies[:32]
+ targets = maze.stationary_densities(
+ mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
+ output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
+ output = F.softmax(output, dim=2)
+ print(f"{output.size()=}")
+ proba_path = output[:, task.height * task.width :, 4].reshape(
+ -1, task.height, task.width
+ )
+ mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
+ paths = task.test_input[:32, task.height * task.width :].reshape(
+ -1, task.height, task.width
+ )
+ filename = f"oneshot.png"
+ maze.save_image(
+ os.path.join(args.result_dir, filename),
+ mazes=mazes,
+ # target_paths=paths,
+ score_paths=proba_path,
+ score_truth=targets,
+ )
+ log_string(f"wrote {filename}")
+
+
+def oneshot_old(gpt, learning_rate_scheduler, task):
+ t = gpt.training
+ gpt.eval()
+
+ if args.oneshot_input == "head":
+ dim_in = args.dim_model
+ elif args.oneshot_input == "deep":
+ dim_in = args.dim_model * args.nb_blocks * 2
+ else:
+ raise ValueError(f"{args.oneshot_input=}")
+
+ if args.oneshot_output == "policy":
+ dim_out = 4
+ compute_loss = oneshot_policy_loss
+ elif args.oneshot_output == "trace":
+ dim_out = 1
+ compute_loss = oneshot_trace_loss
+ else:
+ raise ValueError(f"{args.oneshot_output=}")
+
+ model = nn.Sequential(
+ nn.Linear(dim_in, args.dim_model),
+ nn.ReLU(),
+ nn.Linear(args.dim_model, args.dim_model),
+ nn.ReLU(),
+ nn.Linear(args.dim_model, dim_out),
+ ).to(device)
+
+ learning_rate_scheduler.reset()
+
+ for n_epoch in range(args.nb_epochs):
+ learning_rate = learning_rate_scheduler.get_learning_rate()
+ log_string(f"learning_rate {n_epoch} {learning_rate}")
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+ acc_train_loss, nb_train_samples = 0, 0
+ for mazes, policies in task.policy_batches(split="train"):
+ output_gpt = eval_mygpt(
+ gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
+ )
+ output = model(output_gpt)
+
+ loss = compute_loss(mazes, output, policies, task.height, task.width)
+ acc_train_loss += loss.item() * mazes.size(0)
+ nb_train_samples += mazes.size(0)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ learning_rate_scheduler.update(n_epoch + 1, acc_train_loss)
+
+ acc_test_loss, nb_test_samples = 0, 0
+ for mazes, policies in task.policy_batches(split="test"):
+ output_gpt = eval_mygpt(
+ gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
+ )
+ output = model(output_gpt)
+ loss = compute_loss(mazes, output, policies, task.height, task.width)
+ acc_test_loss += loss.item() * mazes.size(0)
+ nb_test_samples += mazes.size(0)
+
+ log_string(
+ f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
+ )
+
+ # -------------------
+ mazes = task.test_input[:32, : task.height * task.width]
+ policies = task.test_policies[:32]
+ output_gpt = eval_mygpt(
+ gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
+ )
+ output = model(output_gpt)
+ if args.oneshot_output == "policy":
+ targets = policies.permute(0, 2, 1)
+ scores = (
+ (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
+ ).float()
+ elif args.oneshot_output == "trace":
+ targets = maze.stationary_densities(
+ mazes.view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
+ scores = output
+ else:
+ raise ValueError(f"{args.oneshot_output=}")
+
+ scores = scores.reshape(-1, task.height, task.width)
+ mazes = mazes.reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
+ filename = (
+ f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
+ )
+ maze.save_image(
+ os.path.join(args.result_dir, filename),
+ mazes=mazes,
+ score_paths=scores,
+ score_truth=targets,
+ )
+ log_string(f"wrote {filename}")
+
+ # -------------------
+
+ gpt.train(t)
+
+
+######################################################################
+
+
+class LearningRateScheduler:
+ def get_learning_rate(self):
+ pass
+
+ def update(self, nb_finished_epochs, loss):
+ pass
+
+ def reset(self):
+ pass
+
+ def get_state(self):
+ return vars(self)
+
+ def set_state(self, state):
+ print(f"{state=}")
+ for k, v in state.items():
+ setattr(self, k, v)
+
+
+class StepWiseScheduler(LearningRateScheduler):
+ def __init__(self, schedule):
+ self.nb_finished_epochs = 0
+ self.schedule = schedule
+
+ def get_learning_rate(self):
+ return self.schedule[self.nb_finished_epochs]
+
+ def update(self, nb_finished_epochs, loss):
+ self.nb_finished_epochs = nb_finished_epochs
+
+ def reset(self):
+ self.nb_finished_epochs = 0
+
+ def get_state(self):
+ return {"nb_finished_epochs": self.nb_finished_epochs}
+
+
+class AutoScheduler(LearningRateScheduler):
+ def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2):
+ self.learning_rate_init = learning_rate_init
+ self.learning_rate = learning_rate_init
+ self.growth = growth
+ self.degrowth = degrowth
+ self.pred_loss = None
+
+ def get_learning_rate(self):
+ return self.learning_rate
+
+ def update(self, nb_finished_epochs, loss):
+ if self.pred_loss is not None:
+ if loss >= self.pred_loss:
+ self.learning_rate *= self.degrowth
+ else:
+ self.learning_rate *= self.growth
+ self.pred_loss = loss
+
+ def reset(self):
+ self.learning_rate = self.learning_rate_init
+
+ def get_state(self):
+ return {
+ "learning_rate_init": self.learning_rate_init,
+ "pred_loss": self.pred_loss,
+ }
+
+
+######################################################################
+
+