Update.
[beaver.git] / beaver.py
index 4f694da..dfbb7b6 100755 (executable)
--- a/beaver.py
+++ b/beaver.py
@@ -170,8 +170,13 @@ def compute_perplexity(model, split="train"):
 
 
 def one_shot(gpt, task):
-    pass
-
+    # Run the model over the policy batches in eval mode, with the readout
+    # head disabled, then restore the module's previous train/eval state.
+    t = gpt.training
+    gpt.eval()
+    for input, targets in task.policy_batches():
+        output = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+    gpt.train(t)
 
 ######################################################################
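The explicit save/eval/restore around the loop in one_shot() can be packaged as a reusable helper. Below is a minimal sketch in plain PyTorch, not part of the patch; eval_mode is a hypothetical name, and additionally wrapping the loop in torch.no_grad() would avoid building the autograd graph during evaluation.

    import contextlib
    import torch

    @contextlib.contextmanager
    def eval_mode(model: torch.nn.Module):
        # Remember the current mode, as one_shot() does with t = gpt.training.
        was_training = model.training
        model.eval()  # disable dropout, use running batch-norm statistics
        try:
            yield model
        finally:
            # Restore whatever mode the caller had, even if the loop raised.
            model.train(was_training)

Used as "with eval_mode(gpt):" around the inference loop, the mode is restored even on an exception, which the bare t / gpt.train(t) pair does not guarantee.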
 
@@ -215,25 +218,25 @@ class TaskMaze(Task):
         self.width = width
         self.device = device
 
-        mazes_train, paths_train = maze.create_maze_data(
+        train_mazes, train_paths, train_policies = maze.create_maze_data(
             nb_train_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
         )
-        mazes_train, paths_train = mazes_train.to(device), paths_train.to(device)
-        self.train_input = self.map2seq(mazes_train, paths_train)
+        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
+        self.train_policies = train_policies.to(device)
 
-        mazes_test, paths_test = maze.create_maze_data(
+        test_mazes, test_paths, test_policies = maze.create_maze_data(
             nb_test_samples,
             height=height,
             width=width,
             nb_walls=nb_walls,
             progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
         )
-        mazes_test, paths_test = mazes_test.to(device), paths_test.to(device)
-        self.test_input = self.map2seq(mazes_test, paths_test)
+        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
+        self.test_policies = test_policies.to(device)
 
         self.nb_codes = self.train_input.max() + 1
 
@@ -247,6 +250,24 @@ class TaskMaze(Task):
         ):
             yield batch
 
+    def policy_batches(self, split="train", nb_to_use=-1):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        targets = self.train_policies if split == "train" else self.test_policies
+        input = input[:, : self.height * self.width]
+        targets = targets.flatten(-2) * (input != maze.v_wall)[:, None]  # zero targets at wall cells
+
+        if nb_to_use > 0:
+            input = input[:nb_to_use]
+            targets = targets[:nb_to_use]
+
+        for batch in tqdm.tqdm(
+            zip(input.split(self.batch_size), targets.split(self.batch_size)),
+            dynamic_ncols=True,
+            desc=f"epoch-{split}",
+        ):
+            yield batch
+
     def vocabulary_size(self):
         return self.nb_codes
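To make the masking line in policy_batches() concrete, here is a toy reproduction of the broadcast. The shapes are assumptions inferred from the code, i.e. policies of shape (N, 4, height, width) before flatten(-2), with the value 1 standing in for maze.v_wall:

    import torch

    N, A, H, W = 2, 4, 2, 3                       # toy batch, action, maze sizes
    targets = torch.ones(N, A, H, W).flatten(-2)  # (N, A, H*W)
    input = torch.tensor([[0, 1, 0, 0, 1, 0],     # 1 plays the role of maze.v_wall
                          [1, 0, 0, 0, 0, 1]])
    mask = (input != 1)[:, None]                  # (N, 1, H*W)
    masked = targets * mask                       # zeroed wherever the cell is a wall
    assert masked.shape == (N, A, H * W)
    assert masked[0, :, 1].eq(0).all() and masked[0, :, 0].eq(1).all()

The [:, None] indexing inserts a singleton dimension so that one (N, H*W) wall mask broadcasts across all action channels of the targets at once.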