Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 9679236..db982ca 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -102,7 +102,7 @@ parser.add_argument("--snake_width", type=int, default=8)
 
 parser.add_argument("--snake_nb_colors", type=int, default=5)
 
-parser.add_argument("--snake_length", type=int, default=400)
+parser.add_argument("--snake_length", type=int, default=200)
 
 ######################################################################
 
@@ -143,8 +143,8 @@ default_args = {
         "batch_size": 25,
     },
     "snake": {
-        "nb_epochs": 25,
-        "batch_size": 20,
+        "nb_epochs": 5,
+        "batch_size": 25,
     },
 }
 
@@ -173,15 +173,27 @@ for n in vars(args):
 ######################################################################
 
 
+# ar_mask is boolean, with 1s on the values to generate
+
+
 def masked_inplace_autoregression(
-    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
 ):
-    for input, ar_mask in tqdm.tqdm(
-        zip(input.split(batch_size), ar_mask.split(batch_size)),
-        dynamic_ncols=True,
-        desc="autoregression",
-        total=input.size(0) // batch_size,
-    ):
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+    if progress_bar_desc is not None:
+        batches = tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=input.size(0) // batch_size,
+        )
+    for input, ar_mask in batches:
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
@@ -317,6 +329,7 @@ class TaskPicoCLVR(Task):
                 input,
                 ar_masks,
                 forbidden_tokens,
+                progress_bar_desc=None,
                 device=self.device,
             )
             model.train(t)
@@ -689,7 +702,7 @@ class TaskSnake(Task):
         self.device = device
         self.prompt_length = prompt_length
 
-        self.train_input, self.train_prior_visits = snake.generate_sequences(
+        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
             nb_train_samples,
             height,
             width,
@@ -698,7 +711,7 @@ class TaskSnake(Task):
             prompt_length,
             self.device,
         )
-        self.test_input, self.test_prior_visits = snake.generate_sequences(
+        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
             nb_test_samples,
             height,
             width,
@@ -975,9 +988,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)