x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
- # with torch.amp.autocast("cuda"):
- logits_hat_x_0 = model(x_t_with_mask)
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+     logits_hat_x_0 = model(x_t_with_mask)
return logits_hat_x_0
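
The hunk above swaps a commented-out autocast context for an active one, so the forward pass runs in bfloat16. A minimal, self-contained sketch of the same pattern (the linear layer and input below are stand-ins, not the repository's model):

    import torch
    import torch.nn as nn

    model = nn.Linear(256, 256).cuda()          # stand-in for the real model
    x = torch.randn(4, 10, 256, device="cuda")  # stand-in input batch

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = model(x)
    print(y.dtype)  # torch.bfloat16 inside the autocast region
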
for it in range(self.nb_iterations):
    x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-     # with torch.amp.autocast("cuda"):
-     logits = model(x_t_with_mask)
+     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+         logits = model(x_t_with_mask)
    # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
    dist = torch.distributions.categorical.Categorical(logits=logits)
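
At this point one token id per sequence position can be drawn from the Categorical. A small illustrative sketch of the shapes involved (batch size, sequence length, and vocabulary size are made up):

    import torch

    logits = torch.randn(2, 5, 16)  # [batch, seq_len, vocab], illustrative values
    dist = torch.distributions.categorical.Categorical(logits=logits)
    tokens = dist.sample()          # shape [2, 5]: one sampled token id per position

In the surrounding loop, the sampled tokens would presumably only overwrite the positions selected by mask_generate.
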
# torch.set_float32_matmul_precision("high")
+# torch.set_default_dtype(torch.bfloat16)
+
import diffusion
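
The two commented-out lines above are global precision switches that the patch leaves disabled. For reference, a hedged sketch of what enabling them would do (not part of the patch):

    import torch

    # Allow TF32 matmuls on Ampere-or-newer GPUs; float32 tensors keep their
    # dtype, only the internal matmul math is relaxed.
    torch.set_float32_matmul_precision("high")

    # Would make newly created floating-point tensors default to bfloat16.
    # torch.set_default_dtype(torch.bfloat16)
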
######################################################################
nb_train_samples, acc_train_loss = 0, 0.0
- # scaler = torch.amp.GradScaler("cuda")
+ scaler = torch.amp.GradScaler("cuda")
for x_0, mask_generate in ae_batches(
    quiz_machine,
if nb_train_samples % args.batch_size == 0:
    model.optimizer.zero_grad()
- # with torch.amp.autocast("cuda"):
- logits = diffuser.logits_hat_x_0_from_random_iteration(
-     model=model,
-     x_0=x_0,
-     mask_generate=mask_generate,
-     prompt_noise=args.prompt_noise,
- )
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+     logits = diffuser.logits_hat_x_0_from_random_iteration(
+         model=model,
+         x_0=x_0,
+         mask_generate=mask_generate,
+         prompt_noise=args.prompt_noise,
+     )
loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
acc_train_loss += loss.item() * x_0.size(0)
nb_train_samples += x_0.size(0)
- loss.backward()
+ # loss.backward()
- if nb_train_samples % args.batch_size == 0:
-     model.optimizer.step()
+ # if nb_train_samples % args.batch_size == 0:
+ #     model.optimizer.step()
- # scaler.scale(loss).backward()
+ scaler.scale(loss).backward()
- # if nb_train_samples % args.batch_size == 0:
- #     scaler.step(model.optimizer)
+ if nb_train_samples % args.batch_size == 0:
+     scaler.step(model.optimizer)
- # scaler.update()
+     scaler.update()
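
The training-step changes above replace the plain backward/step pair with the standard GradScaler recipe: scale the loss before backward, step through the scaler (which unscales gradients and skips the step if it finds inf/nan), then update the scale factor. Gradient scaling mainly matters for float16; with bfloat16 autocast it is usually unnecessary, though harmless. Below is a self-contained sketch of the pattern under those assumptions, with a toy model, optimizer, and a masked cross-entropy helper standing in for the repository's NTC_masked_cross_entropy:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def masked_cross_entropy(logits, targets, mask):
        # Assumption about what NTC_masked_cross_entropy computes: the mean
        # per-token cross entropy over the positions selected by the mask.
        ce = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
        return (ce * mask).sum() / mask.sum().clamp(min=1)

    model = nn.Linear(16, 16).cuda()  # toy stand-in for the real model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.amp.GradScaler("cuda")

    for step in range(4):
        x = torch.randn(8, 10, 16, device="cuda")                # [batch, seq, features]
        targets = torch.randint(0, 16, (8, 10), device="cuda")   # token ids
        mask = torch.randint(0, 2, (8, 10), device="cuda").float()

        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(x)                                     # [batch, seq, vocab]
            loss = masked_cross_entropy(logits, targets, mask)
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optimizer)         # unscale, then step unless inf/nan was found
        scaler.update()                # adjust the scale factor for the next iteration
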
log_string(
    f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"