X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=e741094c1543483514aae45e391d1d3573030a2b;hb=8f806c62be589d2837f88ffca084ed2ae833124c;hp=6d9f69d65120dc5152c4a38ed2cf72a7a7682b72;hpb=943a440a83b98de60bad767a9ad09f63b5088514;p=picoclvr.git diff --git a/main.py b/main.py index 6d9f69d..e741094 100755 --- a/main.py +++ b/main.py @@ -5,6 +5,9 @@ # Written by Francois Fleuret +# torch.backends.cuda.matmul.allow_tf23 +# torch.autocast(torch.bfloat16) + import math, sys, argparse, time, tqdm, itertools, os import torch, torchvision @@ -15,12 +18,16 @@ import mygpt, tensorstack ###################################################################### -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cuda.matmul.allow_tf32 = True +else: + device = torch.device("cpu") ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache to solve a toy geometric reasonning task." + description="An implementation of GPT with cache to solve a toy geometric reasoning task." ) parser.add_argument("--log_filename", type=str, default="train.log") @@ -55,8 +62,6 @@ parser.add_argument("--nb_blocks", type=int, default=12) parser.add_argument("--dropout", type=float, default=0.1) -parser.add_argument("--nb_oneshot_blocks", type=int, default=-1) - parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--no_checkpoint", action="store_true", default=False) @@ -89,7 +94,7 @@ except FileExistsError: print(f"result directory {args.result_dir} already exists") exit(1) -log_file = open(os.path.join(args.result_dir, args.log_filename), "w") +log_file = open(os.path.join(args.result_dir, args.log_filename), "a") if args.seed >= 0: # torch.backends.cudnn.deterministic = True @@ -421,9 +426,7 @@ class TaskPicoCLVR(Task): f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%" ) - img = picoclvr.descr2img( - result_descr, [0], height=self.height, width=self.width - ) + img = picoclvr.descr2img(result_descr, height=self.height, width=self.width) if img.dim() == 5: if img.size(1) == 1: