Update.
[picoclvr.git] / main.py
diff --git a/main.py b/main.py
index 6d9f69d..e741094 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -5,6 +5,9 @@
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# torch.backends.cuda.matmul.allow_tf23
+# torch.autocast(torch.bfloat16)
+
 import math, sys, argparse, time, tqdm, itertools, os
 
 import torch, torchvision
@@ -15,12 +18,16 @@ import mygpt, tensorstack
 
 ######################################################################
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
 
 ######################################################################
 
 parser = argparse.ArgumentParser(
-    description="An implementation of GPT with cache to solve a toy geometric reasonning task."
+    description="An implementation of GPT with cache to solve a toy geometric reasoning task."
 )
 
 parser.add_argument("--log_filename", type=str, default="train.log")
@@ -55,8 +62,6 @@ parser.add_argument("--nb_blocks", type=int, default=12)
 
 parser.add_argument("--dropout", type=float, default=0.1)
 
-parser.add_argument("--nb_oneshot_blocks", type=int, default=-1)
-
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
@@ -89,7 +94,7 @@ except FileExistsError:
         print(f"result directory {args.result_dir} already exists")
         exit(1)
 
-log_file = open(os.path.join(args.result_dir, args.log_filename), "w")
+log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
 
 if args.seed >= 0:
     # torch.backends.cudnn.deterministic = True
@@ -421,9 +426,7 @@ class TaskPicoCLVR(Task):
             f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
         )
 
-        img = picoclvr.descr2img(
-            result_descr, [0], height=self.height, width=self.width
-        )
+        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
 
         if img.dim() == 5:
             if img.size(1) == 1: