Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 04e5652..79841f3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -133,6 +133,10 @@ parser.add_argument("--rpl_no_prog", action="store_true", default=False)
 
 parser.add_argument("--grid_size", type=int, default=6)
 
+parser.add_argument("--grid_nb_colors", type=int, default=6)
+
+parser.add_argument("--grid_nb_shapes", type=int, default=6)
+
 ##############################
 # picoclvr options
 
@@ -701,6 +705,8 @@ elif args.task == "grid":
         nb_test_samples=args.nb_test_samples,
         batch_size=args.batch_size,
         size=args.grid_size,
+        nb_shapes=args.grid_nb_shapes,
+        nb_colors=args.grid_nb_colors,
         logger=log_string,
         device=device_data,
     )
@@ -835,21 +841,22 @@ if args.max_percents_of_test_in_train >= 0:
 
 ##############################
 
-for input in task.batches(split="train", desc="calibrate"):
-    input = input.to(device)
-    output = model(mygpt.BracketedSequence(input)).x
-
-for n, m in model.named_modules():
-    for a in dir(m):
-        x = getattr(m, a)
-        if isinstance(x, mygpt.Calibrator):
-            print(f"####### ${n} | ${a} ########################")
-            mean, std = x.moments()
-            print("mean\n", mean, "\n")
-            print("std\n", std, "\n")
-            print(f"############################################\n\n")
-
-exit(0)
+if "calibrate" in sup_args:
+    for input in task.batches(split="train", desc="calibrate"):
+        input = input.to(device)
+        output = model(mygpt.BracketedSequence(input)).x
+
+    for n, m in model.named_modules():
+        for a in dir(m):
+            x = getattr(m, a)
+            if isinstance(x, mygpt.Calibrator):
+                print(f"####### ${n} | ${a} ########################")
+                mean, std = x.moments()
+                print("mean\n", mean, "\n")
+                print("std\n", std, "\n")
+                print(f"############################################\n\n")
+
+    exit(0)
 
 ##############################