# Written by Francois Fleuret <francois@fleuret.org>
+# torch.backends.cuda.matmul.allow_tf23
+# torch.autocast(torch.bfloat16)
+
import math, sys, argparse, time, tqdm, itertools, os
import torch, torchvision
######################################################################
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
######################################################################
parser.add_argument("--dropout", type=float, default=0.1)
-parser.add_argument("--nb_oneshot_blocks", type=int, default=-1)
-
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
parser.add_argument("--no_checkpoint", action="store_true", default=False)