Update.
[mygptrnn.git] / main.py
diff --git a/main.py b/main.py
index 74e1d6c..969b47f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,14 +16,6 @@ import mygpt, tasks, problems
 
 ######################################################################
 
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    torch.backends.cuda.matmul.allow_tf32 = True
-else:
-    device = torch.device("cpu")
-
-######################################################################
-
 
 def str2bool(x):
     x = x.lower()
@@ -55,6 +47,8 @@ parser.add_argument("--seed", type=int, default=0)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
 
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
 ########################################
 
 parser.add_argument("--nb_epochs", type=int, default=50)
@@ -117,7 +111,7 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
 
 parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument("--overwrite_results", action="store_true", default=False)
+parser.add_argument("--continue_training", action="store_true", default=False)
 
 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
@@ -217,6 +211,14 @@ if args.result_dir is None:
 
 ######################################################################
 
+if not args.force_cpu and torch.cuda.is_available():
+    device = torch.device("cuda")
+    torch.backends.cuda.matmul.allow_tf32 = True
+else:
+    device = torch.device("cpu")
+
+######################################################################
+
 default_task_args = {
     "addition": {
         "model": "352M",
@@ -426,7 +428,7 @@ else:
 try:
     os.mkdir(args.result_dir)
 except FileExistsError:
-    if not args.overwrite_results:
+    if not args.continue_training:
         print(f"result directory {args.result_dir} already exists")
         exit(1)
 
@@ -832,7 +834,7 @@ if nb_epochs_finished >= nb_epochs:
         deterministic_synthesis=args.deterministic_synthesis,
     )
 
-time_pred_result = None
+time_pred_result = datetime.datetime.now()
 
 it = 0
 
@@ -910,10 +912,9 @@ for n_epoch in range(nb_epochs_finished, nb_epochs):
         )
 
         time_current_result = datetime.datetime.now()
-        if time_pred_result is not None:
-            log_string(
-                f"next_result {time_current_result + (time_current_result - time_pred_result)}"
-            )
+        log_string(
+            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+        )
         time_pred_result = time_current_result
 
     checkpoint = {