From 39e24a2f9076db2d512791e723e7f2dc0275d99c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 13 Mar 2023 21:23:20 +0100 Subject: [PATCH] Update. --- beaver.py | 37 +++++++++++++++++++++++++++++-------- maze.py | 22 ++++++++++++---------- mygpt.py | 4 ++-- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/beaver.py b/beaver.py index 4f694da..dfbb7b6 100755 --- a/beaver.py +++ b/beaver.py @@ -170,8 +170,11 @@ def compute_perplexity(model, split="train"): def one_shot(gpt, task): - pass - + t = gpt.training + gpt.eval() + for input, targets in task.policy_batches(): + output = gpt(mygpt.BracketedSequence(input), with_readout = False).x + gpt.train(t) ###################################################################### @@ -215,25 +218,25 @@ class TaskMaze(Task): self.width = width self.device = device - mazes_train, paths_train = maze.create_maze_data( + train_mazes, train_paths, train_policies = maze.create_maze_data( nb_train_samples, height=height, width=width, nb_walls=nb_walls, progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"), ) - mazes_train, paths_train = mazes_train.to(device), paths_train.to(device) - self.train_input = self.map2seq(mazes_train, paths_train) + self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device)) + self.train_policies = train_policies.to(device) - mazes_test, paths_test = maze.create_maze_data( + test_mazes, test_paths, test_policies = maze.create_maze_data( nb_test_samples, height=height, width=width, nb_walls=nb_walls, progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"), ) - mazes_test, paths_test = mazes_test.to(device), paths_test.to(device) - self.test_input = self.map2seq(mazes_test, paths_test) + self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device)) + self.test_policies = test_policies.to(device) self.nb_codes = self.train_input.max() + 1 @@ -247,6 +250,24 @@ class TaskMaze(Task): ): yield batch + def policy_batches(self, split="train", nb_to_use=-1): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + targets = self.train_policies if split == "train" else self.test_policies + input = input[:, : self.height * self.width] + targets = targets.flatten(-2) * (input != maze.v_wall)[:,None] + + if nb_to_use > 0: + input = input[:nb_to_use] + targets = targets[:nb_to_use] + + for batch in tqdm.tqdm( + zip(input.split(self.batch_size), targets.split(self.batch_size)), + dynamic_ncols=True, + desc=f"epoch-{split}", + ): + yield batch + def vocabulary_size(self): return self.nb_codes diff --git a/maze.py b/maze.py index cfdede3..d11ab6e 100755 --- a/maze.py +++ b/maze.py @@ -113,18 +113,16 @@ def compute_policy(walls, i, j): ###################################################################### -def mark_path(walls, i, j, goal_i, goal_j): - policy = compute_policy(walls, goal_i, goal_j) +def mark_path(walls, i, j, goal_i, goal_j, policy): action = torch.distributions.categorical.Categorical( policy.permute(1, 2, 0) ).sample() - walls[i, j] = 4 n, nmax = 0, walls.numel() while i != goal_i or j != goal_j: di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]] i, j = i + di, j + dj assert walls[i, j] == 0 - walls[i, j] = 4 + walls[i, j] = v_path n += 1 assert n < nmax @@ -160,6 +158,7 @@ def create_maze_data( ): mazes = torch.empty(nb, height, width, dtype=torch.int64) paths = torch.empty(nb, height, width, dtype=torch.int64) + policies = torch.empty(nb, 4, height, width, dtype=torch.int64) for n in progress_bar(range(nb)): maze = create_maze(height, width, nb_walls) @@ -168,18 +167,21 @@ def create_maze_data( start, goal = i[torch.randperm(i.size(0))[:2]] if (start - goal).abs().sum() >= dist_min: break + start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1] + policy = compute_policy(maze, goal_i, goal_j) path = maze.clone() - mark_path(path, start[0], start[1], goal[0], goal[1]) - maze[start[0], start[1]] = v_start - maze[goal[0], goal[1]] = v_goal - path[start[0], start[1]] = v_start - path[goal[0], goal[1]] = v_goal + mark_path(path, start_i, start_j, goal_i, goal_j, policy) + maze[start_i, start_j] = v_start + maze[goal_i, goal_j] = v_goal + path[start_i, start_j] = v_start + path[goal_i, goal_j] = v_goal mazes[n] = maze paths[n] = path + policies[n] = policy - return mazes, paths + return mazes, paths, policies ###################################################################### diff --git a/mygpt.py b/mygpt.py index a0f3dbf..d424eef 100755 --- a/mygpt.py +++ b/mygpt.py @@ -246,11 +246,11 @@ class MyGPT(nn.Module): m.bias.zero_() m.weight.fill_(1.0) - def forward(self, bs): + def forward(self, bs, with_readout = True): bs.x = F.pad(bs.x, (1, -1)) bs = self.embedding(bs) bs = self.trunk(bs) - bs = self.readout(bs) + if with_readout: bs = self.readout(bs) return bs -- 2.39.5