parser.add_argument("--nb_train_samples", type=int, default=25000)
-parser.add_argument("--nb_test_samples", type=int, default=10000)
+parser.add_argument("--nb_test_samples", type=int, default=1000)
parser.add_argument("--nb_train_alien_samples", type=int, default=0)
x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
- logits_hat_x_0 = model(x_t_with_mask)
+
+ with torch.cuda.amp.autocast():
+ logits_hat_x_0 = model(x_t_with_mask)
return logits_hat_x_0
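For reference, torch.cuda.amp.autocast only changes the dtype of the operations executed inside the with block; the tensor it returns keeps whatever dtype the last op produced (typically float16 on CUDA), and autocast regions nest, so the extra autocast added around the call site in the training loop further down is redundant but harmless. A minimal sketch of the pattern, with a generic module and a dummy input standing in for model and x_t_with_mask (both stand-ins, not this project's code):

    import torch
    import torch.nn as nn

    net = nn.Linear(16, 8).cuda()              # stand-in for model
    x = torch.randn(4, 16, device="cuda")      # stand-in for x_t_with_mask

    with torch.cuda.amp.autocast():
        y = net(x)                             # matmul-based ops run in float16

    print(y.dtype)                             # torch.float16
    y32 = y.float()                            # explicit cast if a later op needs float32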
for it in range(nb_iterations_max):
x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
- logits = model(x_t_with_mask)
+ with torch.cuda.amp.autocast():
+ logits = model(x_t_with_mask)
logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
dist = torch.distributions.categorical.Categorical(logits=logits)
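The sampling path follows the same shape: only the forward pass needs to sit inside autocast, while the logit masking and the Categorical sampling run outside the block on whatever the model returned. A small self-contained sketch of that tail end, with made-up shapes and a float16 tensor simulating the autocast output (nb_colors = 10 is a placeholder, not the project's value):

    import torch

    nb_colors = 10                                          # placeholder value
    logits = torch.randn(2, 5, 12, dtype=torch.float16)     # simulates the autocast output

    logits = logits.float()                                 # optional cast; float32 sampling is the safe default
    logits[:, :, nb_colors:] = float("-inf")                # forbid the non-color symbols
    dist = torch.distributions.categorical.Categorical(logits=logits)
    tokens = dist.sample()                                  # int64 tensor of shape (2, 5)

Casting back to float32 before building the distribution is not required by the patch above, but it sidesteps half-precision corner cases in the sampling ops.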
nb_train_samples, acc_train_loss = 0, 0.0
+ scaler = torch.cuda.amp.GradScaler()
+
for x_0, mask_generate in ae_batches(
quiz_machine,
args.nb_train_samples,
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- logits = logits_hat_x_0_from_random_iteration(
- model, x_0, mask_generate, prompt_noise=args.prompt_noise
- )
+ with torch.cuda.amp.autocast():
+ logits = logits_hat_x_0_from_random_iteration(
+ model, x_0, mask_generate, prompt_noise=args.prompt_noise
+ )
loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
acc_train_loss += loss.item() * x_0.size(0)
nb_train_samples += x_0.size(0)
- loss.backward()
+ scaler.scale(loss).backward()
if nb_train_samples % args.batch_size == 0:
- model.optimizer.step()
+ scaler.step(model.optimizer)
+
+ scaler.update()
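For reference, the standard torch.cuda.amp recipe these training changes follow is: run the forward pass and loss under autocast, scale the loss before backward so float16 gradients do not underflow, then let the scaler unscale the gradients, step the optimizer, and update its scale factor. A minimal self-contained sketch with hypothetical names (net, opt, a few random batches), not the project's actual training loop:

    import torch

    net = torch.nn.Linear(16, 8).cuda()
    opt = torch.optim.SGD(net.parameters(), lr=1e-2)
    scaler = torch.cuda.amp.GradScaler()

    for x in [torch.randn(4, 16, device="cuda") for _ in range(3)]:
        opt.zero_grad()
        with torch.cuda.amp.autocast():
            loss = net(x).pow(2).mean()        # forward + loss in mixed precision
        scaler.scale(loss).backward()          # scaled backward to avoid fp16 gradient underflow
        scaler.step(opt)                       # unscales grads, skips the step on overflow
        scaler.update()                        # adjusts the scale factor for the next step

When gradients are accumulated as in the loop above, with zero_grad and the optimizer step gated on nb_train_samples % args.batch_size == 0, scaler.scale(loss).backward() still runs on every batch while scaler.step and scaler.update run only at that boundary, which matches the structure of this hunk.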
log_string(
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"