Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 09:48:04 +0000 (11:48 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 09:48:04 +0000 (11:48 +0200)
main.py

diff --git a/main.py b/main.py
index dede204..8010fa4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -57,7 +57,7 @@ parser.add_argument("--inference_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=25000)
 
-parser.add_argument("--nb_test_samples", type=int, default=10000)
+parser.add_argument("--nb_test_samples", type=int, default=1000)
 
 parser.add_argument("--nb_train_alien_samples", type=int, default=0)
 
@@ -719,7 +719,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
         x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
 
     x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-    logits_hat_x_0 = model(x_t_with_mask)
+
+    with torch.cuda.amp.autocast():
+        logits_hat_x_0 = model(x_t_with_mask)
 
     return logits_hat_x_0
 
@@ -743,7 +745,8 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None
 
     for it in range(nb_iterations_max):
         x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-        logits = model(x_t_with_mask)
+        with torch.cuda.amp.autocast():
+            logits = model(x_t_with_mask)
         logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
         dist = torch.distributions.categorical.Categorical(logits=logits)
 
@@ -891,6 +894,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
+    scaler = torch.cuda.amp.GradScaler()
+
     for x_0, mask_generate in ae_batches(
         quiz_machine,
         args.nb_train_samples,
@@ -905,18 +910,21 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        logits = logits_hat_x_0_from_random_iteration(
-            model, x_0, mask_generate, prompt_noise=args.prompt_noise
-        )
+        with torch.cuda.amp.autocast():
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
+            )
 
         loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
-        loss.backward()
+        scaler.scale(loss).backward()
 
         if nb_train_samples % args.batch_size == 0:
-            model.optimizer.step()
+            scaler.step(model.optimizer)
+
+        scaler.update()
 
     log_string(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"