From dc67c779ee5a07ae4dde34e827755020bfcd71e0 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sun, 15 Sep 2024 12:40:21 +0200
Subject: [PATCH] Update.

---
 main.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/main.py b/main.py
index 19a3fee..49799e4 100755
--- a/main.py
+++ b/main.py
@@ -59,7 +59,7 @@ parser.add_argument("--physical_batch_size", type=int, default=None)
 
 parser.add_argument("--inference_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=25000)
+parser.add_argument("--nb_train_samples", type=int, default=100000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
@@ -567,7 +567,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    scaler = torch.amp.GradScaler("cuda")
+    # scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -595,17 +595,17 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
-        # loss.backward()
+        loss.backward()
 
-        # if nb_train_samples % args.batch_size == 0:
-        #     model.optimizer.step()
+        if nb_train_samples % args.batch_size == 0:
+            model.optimizer.step()
 
-        scaler.scale(loss).backward()
+        # scaler.scale(loss).backward()
 
-        if nb_train_samples % args.batch_size == 0:
-            scaler.step(model.optimizer)
+        # if nb_train_samples % args.batch_size == 0:
+        #     scaler.step(model.optimizer)
 
-        scaler.update()
+        # scaler.update()
 
     log_string(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
-- 
2.39.5
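
For context, the loop this patch switches to accumulates gradients over
physical batches and steps the optimizer once per logical batch of
args.batch_size samples, with the AMP GradScaler path commented out rather
than deleted. Below is a minimal sketch of both variants. The names model,
optimizer, batches, and batch_size are hypothetical stand-ins for the
repository's own objects; only the torch.amp calls are the standard PyTorch
API. Two deliberate departures from the hunks above, both assumptions: the
sketch calls scaler.update() only right after scaler.step(), which is the
documented AMP pattern (the removed code called it on every physical batch),
and it adds optimizer.zero_grad() calls, which the hunks never show and which
presumably happen elsewhere in main.py.

# Sketch: gradient accumulation with and without AMP loss scaling.
# Hypothetical stand-ins for the repo's objects: model, optimizer,
# batches (an iterable of input tensors), batch_size (logical batch).
import torch

def train_epoch(model, optimizer, batches, batch_size, use_amp=False):
    # The patch disables this scaler; pass use_amp=True to re-enable it.
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    nb_train_samples, acc_train_loss = 0, 0.0

    optimizer.zero_grad()  # assumption: not shown in the hunks above

    for x_0 in batches:
        loss = model(x_0).mean()  # assume a scalar training loss
        acc_train_loss += loss.item() * x_0.size(0)
        nb_train_samples += x_0.size(0)

        if scaler is None:
            # Plain path (re-enabled by the patch): accumulate gradients
            # over physical batches, step once per logical batch.
            loss.backward()
            if nb_train_samples % batch_size == 0:
                optimizer.step()
                optimizer.zero_grad()
        else:
            # AMP path (commented out by the patch): scale the loss so
            # small fp16 gradients do not underflow; step() unscales and
            # skips inf/nan steps, update() adjusts the scale factor.
            scaler.scale(loss).backward()
            if nb_train_samples % batch_size == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

    return acc_train_loss / nb_train_samples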