From d63c681fdb2d6b5590991eaa4a2d9a5376678c67 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 20 Mar 2023 14:06:24 +0100
Subject: [PATCH] Update

---
 beaver.py | 45 +++++++++++++++++++++++----------------------
 maze.py   | 19 +++++++++++--------
 2 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/beaver.py b/beaver.py
index e7decd1..33d174d 100755
--- a/beaver.py
+++ b/beaver.py
@@ -188,16 +188,19 @@ def one_shot(gpt, task):
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
     acc_train_loss, nb_train_samples = 0, 0
-    for input, targets in task.policy_batches(split="train"):
+    for input, policies in task.policy_batches(split="train"):
+        ####
+        # print(f'{input.size()=} {policies.size()=}')
+        # s = maze.stationary_densities(
+        # exit(0)
+        ####
+        mask = input.unsqueeze(-1) == maze.v_empty
         output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_mode).x
         output = model(output_gpt)
-        targets = targets * (input.unsqueeze(-1) == maze.v_empty)
-        output = output * (input.unsqueeze(-1) == maze.v_empty)
+        targets = policies.permute(0, 2, 1) * mask
+        output = output * mask
         # loss = (output.softmax(-1) - targets).abs().max(-1).values.mean()
-        loss = (
-            -(output.log_softmax(-1) * targets).sum()
-            / (input == maze.v_empty).sum()
-        )
+        loss = -(output.log_softmax(-1) * targets).sum() / mask.sum()
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
 
@@ -206,16 +209,14 @@ def one_shot(gpt, task):
         optimizer.step()
 
     acc_test_loss, nb_test_samples = 0, 0
-    for input, targets in task.policy_batches(split="test"):
+    for input, policies in task.policy_batches(split="test"):
+        mask = input.unsqueeze(-1) == maze.v_empty
         output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_mode).x
         output = model(output_gpt)
-        targets = targets * (input.unsqueeze(-1) == maze.v_empty)
-        output = output * (input.unsqueeze(-1) == maze.v_empty)
+        targets = policies.permute(0, 2, 1) * mask
+        output = output * mask
         # loss = (output.softmax(-1) - targets).abs().max(-1).values.mean()
-        loss = (
-            -(output.log_softmax(-1) * targets).sum()
-            / (input == maze.v_empty).sum()
-        )
+        loss = -(output.log_softmax(-1) * targets).sum() / mask.sum()
         acc_test_loss += loss.item() * input.size(0)
         nb_test_samples += input.size(0)
 
@@ -225,11 +226,11 @@ def one_shot(gpt, task):
 
     # -------------------
     input = task.test_input[:32, : task.height * task.width]
-    targets = task.test_policies[:32]
+    targets = task.test_policies[:32].permute(0, 2, 1)
     output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_mode).x
     output = model(output_gpt)
     # losses = (-output.log_softmax(-1) * targets + targets.xlogy(targets)).sum(-1)
-    # losses = losses * (input == maze.v_empty)
+    # losses = losses * mask
     # losses = losses / losses.max()
     # losses = (output.softmax(-1) - targets).abs().max(-1).values
     # losses = (losses >= 0.05).float()
@@ -300,7 +301,7 @@ class TaskMaze(Task):
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
         )
         self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
-        self.train_policies = train_policies.flatten(-2).permute(0, 2, 1).to(device)
+        self.train_policies = train_policies.flatten(-2).to(device)
 
         test_mazes, test_paths, test_policies = maze.create_maze_data(
             nb_test_samples,
@@ -310,7 +311,7 @@ class TaskMaze(Task):
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
         )
         self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
-        self.test_policies = test_policies.flatten(-2).permute(0, 2, 1).to(device)
+        self.test_policies = test_policies.flatten(-2).to(device)
 
         self.nb_codes = self.train_input.max() + 1
 
@@ -327,16 +328,16 @@ class TaskMaze(Task):
     def policy_batches(self, split="train", nb_to_use=-1):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
-        targets = self.train_policies if split == "train" else self.test_policies
+        policies = self.train_policies if split == "train" else self.test_policies
         input = input[:, : self.height * self.width]
-        targets = targets * (input != maze.v_wall)[:, :, None]
+        policies = policies * (input != maze.v_wall)[:, None]
 
         if nb_to_use > 0:
             input = input[:nb_to_use]
-            targets = targets[:nb_to_use]
+            policies = policies[:nb_to_use]
 
         for batch in tqdm.tqdm(
-            zip(input.split(self.batch_size), targets.split(self.batch_size)),
+            zip(input.split(self.batch_size), policies.split(self.batch_size)),
             dynamic_ncols=True,
             desc=f"epoch-{split}",
         ):
diff --git a/maze.py b/maze.py
index 36eef25..d09e860 100755
--- a/maze.py
+++ b/maze.py
@@ -110,19 +110,22 @@ def compute_policy(walls, goal_i, goal_j):
     return proba
 
 
-def stationary_density(policy, start_i, start_j):
-    probas = policy.new_zeros(policy.size()[:-1])
+def stationary_densities(mazes, policies):
+    start = (mazes == v_start).nonzero(as_tuple=True)
+    probas = mazes.new_zeros(mazes.size())
     pred_probas = probas.clone()
-    probas[start_i, start_j] = 1.0
+    probas[start] = 1.0
 
     while not pred_probas.equal(probas):
         pred_probas.copy_(probas)
         probas.zero_()
-        probas[1:, :] = pred_probas[:-1, :] * policy[0, :-1, :]
-        probas[:-1, :] = pred_probas[1:, :] * policy[1, 1:, :]
-        probas[:, 1:] = pred_probas[:, :-1] * policy[2, :, :-1]
-        probas[:, :-1] = pred_probas[:, 1:] * policy[3, :, 1:]
-        probas[start_i, start_j] = 1.0
+        probas[:, 1:, :] = pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
+        probas[:, :-1, :] = pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
+        probas[:, :, 1:] = pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
+        probas[:, :, :-1] = pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
+        probas[start] = 1.0
+
+    return probas
 
 
 ######################################################################
-- 
2.20.1
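
Note (not part of the patch): the maze.py hunk replaces the per-maze stationary_density(policy, start_i, start_j) with a batched stationary_densities(mazes, policies) that locates the start cells itself (mazes == v_start) and propagates probability mass along the four moves until consecutive iterates are equal. The sketch below illustrates that fixed-point iteration in a standalone way; the v_start code, grid size, uniform policies and explicit float dtype are assumptions for illustration only, and the sketch accumulates the four neighbour contributions with += (my reading of the intent), whereas the hunk as posted uses plain assignments.

import torch

v_start = 4  # assumed integer code for the start cell (illustration only)

def stationary_densities(mazes, policies):
    # mazes: (N, H, W) integer grids, policies: (N, 4, H, W) per-cell action probabilities
    start = (mazes == v_start).nonzero(as_tuple=True)
    probas = mazes.new_zeros(mazes.size(), dtype=policies.dtype)  # float, unlike the integer mazes
    pred_probas = probas.clone()
    probas[start] = 1.0

    while not pred_probas.equal(probas):
        pred_probas.copy_(probas)
        probas.zero_()
        # push every cell's mass to its four neighbours, weighted by the source cell's policy
        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 0, :-1, :]
        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 1, 1:, :]
        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 2, :, :-1]
        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 3, :, 1:]
        probas[start] = 1.0  # keep re-injecting unit mass at the start cells

    return probas

# toy usage: two 5x5 grids with a start cell in the middle and uniform policies
mazes = torch.zeros(2, 5, 5, dtype=torch.int64)
mazes[:, 2, 2] = v_start
policies = torch.full((2, 4, 5, 5), 0.25)
print(stationary_densities(mazes, policies).size())  # torch.Size([2, 5, 5])

Reading the beaver.py hunks with the same shape convention: policy_batches() now yields policies of shape (N, 4, H*W), and one_shot() permutes them to (N, H*W, 4) and masks out non-empty cells before the cross-entropy-style loss normalised by mask.sum().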