Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 09:30:14 +0000 (11:30 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 09:30:14 +0000 (11:30 +0200)
diffusion.py [new file with mode: 0755]
main.py

diff --git a/diffusion.py b/diffusion.py
new file mode 100755 (executable)
index 0000000..98d8d0a
--- /dev/null
@@ -0,0 +1,137 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
+def NTC_channel_cat(*x):
+    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
+class Diffuser:
+    def __init__(self, mu_T_sampler, nb_iterations, proba_corruption):
+        self.mu_T_sampler = mu_T_sampler
+        self.nb_iterations = nb_iterations
+        self.proba_corruption = proba_corruption
+
+    def sample_x_t_given_x_0(self, x_0, t):
+        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+        r = torch.rand(x_0.size(), device=x_0.device)
+        proba_erased = 1 - (1 - self.proba_corruption) ** t
+        mask_erased = (r <= proba_erased[:, None]).long()
+        x_t = (1 - mask_erased) * x_0 + mask_erased * noise
+
+        return x_t
+
+    # This function returns a 2d tensor of same shape as low, full of
+    # uniform random values in [0,1], such that, in every row, the values
+    # corresponding to the True in low are all lesser than the values
+    # corresponding to the False.
+
+    def prioritized_rand(self, low):
+        x = (
+            torch.rand(low.size(), device=low.device)
+            .sort(dim=1, descending=True)
+            .values
+        )
+        k = torch.rand(low.size(), device=low.device) + low.long()
+        k = k.sort(dim=1).indices
+        y = x.new(x.size())
+        y.scatter_(dim=1, index=k, src=x)
+        return y
+
+    def sample_x_t_minus_1_given_x_0_x_t(self, x_0, x_t):
+        r = self.prioritized_rand(x_0 != x_t)
+        mask_changes = (r <= self.proba_corruption).long()
+        x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
+        return x_t_minus_1
+
+    ######################################################################
+
+    # This function gets a clean target x_0, and a mask indicating which
+    # part to generate (conditionnaly to the others), and returns the
+    # logits starting from a x_t|X_0=x_0 picked at random with t random
+
+    def logits_hat_x_0_from_random_iteration(
+        self, model, x_0, mask_generate, prompt_noise=0.0
+    ):
+        # We favor iterations near the clean signal
+
+        probs_iterations = 0.1 ** torch.linspace(
+            0, 1, self.nb_iterations, device=x_0.device
+        )
+
+        probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+        probs_iterations = probs_iterations.expand(x_0.size(0), -1)
+        dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+
+        t = dist.sample() + 1
+
+        x_t = self.sample_x_t_given_x_0(x_0, t)
+
+        # Only the part to generate is degraded, the rest is a perfect
+        # noise-free conditionning
+
+        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+
+        # We may inject noise to prevent high-complexity non-structure
+        # signal to be generated as a way of "increasing reasoning
+        # complexity"
+
+        if prompt_noise > 0:
+            mask_prompt_noise = (
+                torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
+            ).long()
+            noise = self.mu_T_sampler(x_t.size(), device=x_t.device)
+            noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
+            x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
+
+        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+
+        # with torch.amp.autocast("cuda"):
+        logits_hat_x_0 = model(x_t_with_mask)
+
+        return logits_hat_x_0
+
+    ######################################################################
+
+    def ae_generate(self, model, x_0, mask_generate, mask_hints=None):
+        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+
+        single_iteration = (
+            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+        ).long()[:, None]
+
+        if mask_hints is None:
+            mask_start = mask_generate
+        else:
+            mask_start = mask_generate * (1 - mask_hints)
+
+        x_t = (1 - mask_start) * x_0 + mask_start * noise
+
+        changed = True
+
+        for it in range(self.nb_iterations):
+            x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+            # with torch.amp.autocast("cuda"):
+            logits = model(x_t_with_mask)
+            # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+
+            hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
+
+            hat_x_t_minus_1 = single_iteration * hat_x_0 + (
+                1 - single_iteration
+            ) * self.sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
+
+            if hat_x_t_minus_1.equal(x_t):
+                break
+            else:
+                changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
+                x_t[changed] = hat_x_t_minus_1[changed]
+
+        return x_t
diff --git a/main.py b/main.py
index 01ce963..534bab9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -23,9 +23,11 @@ from quiz_machine import one_batch_masked_inplace_autoregression
 
 import threading, subprocess
 
-import torch.multiprocessing as mp
+import torch.multiprocessing as mp
 
-torch.set_float32_matmul_precision("high")
+# torch.set_float32_matmul_precision("high")
+
+import diffusion
 
 ######################################################################
 
@@ -107,9 +109,7 @@ parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
 
-parser.add_argument("--diffusion_delta", type=float, default=0.05)
-
-parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
+parser.add_argument("--proba_diffusion_corruption", type=float, default=0.05)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
@@ -322,6 +322,15 @@ quiz_machine = quiz_machine.QuizMachine(
     device=main_device,
 )
 
+
+def mu_T_sampler(shape, device="cpu"):
+    return torch.randint(quiz_machine.problem.nb_colors, shape, device=device)
+
+
+diffuser = diffusion.Diffuser(
+    mu_T_sampler, args.nb_diffusion_iterations, args.proba_diffusion_corruption
+)
+
 ######################################################################
 
 log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
@@ -363,125 +372,53 @@ def optimizer_to(optim, device):
 
 ######################################################################
 
+# quad_order, quad_generate, quad_noise, quad_loss
 
-from mygpt import (
-    CachedWithResidual,
-    CacheWrapper,
-    CachedVaswaniPositionalEncoding,
-    QKVAttention,
-    BracketedSequence,
-)
-
-
-class MultiEmbedding(nn.Module):
-    def __init__(self, nb_values, dim):
-        super().__init__()
-        self.embeddings = nn.ModuleList([nn.Embedding(n, dim) for n in nb_values])
-
-    def forward(self, x):
-        y = 0
-        for f, z in zip(self.embeddings, x.split(1, dim=2)):
-            y = y + f(z[:, :, 0])
-        return y
-
-
-def attention_block(dim_model, dim_keys, nb_heads, dropout):
-    return CachedWithResidual(
-        CacheWrapper(
-            nn.LayerNorm((dim_model,)),
-        ),
-        QKVAttention(
-            dim_in=dim_model,
-            dim_qk=dim_keys,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            attention_dropout=dropout,
-        ),
-    )
-
-
-def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
-    return CachedWithResidual(
-        CacheWrapper(
-            nn.LayerNorm((dim_model,)),
-            nn.Linear(in_features=dim_model, out_features=dim_hidden),
-            nn.ReLU(),
-            nn.Linear(in_features=dim_hidden, out_features=dim_model),
-            nn.Dropout(dropout),
-        ),
-    )
-
-
-class MyAttentionAE(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        dropout=0.0,
-        len_max=1024,
-    ):
-        super().__init__()
+data_structures = [
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
 
-        assert dim_model % nb_heads == 0
 
-        self.embedding = CacheWrapper(
-            nn.Sequential(
-                MultiEmbedding((vocabulary_size, 2), dim_model),
-                nn.Dropout(dropout),
-            ),
-        )
+######################################################################
 
-        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
-        self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5)
 
-        trunk_blocks = []
+def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
+    record = []
 
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                attention_block(dim_model, dim_keys, nb_heads, dropout),
-                ffw_block(dim_model, dim_hidden, nb_heads, dropout),
-            ]
+    for x_0 in input.split(args.batch_size):
+        loss = 0
 
-        self.trunk = nn.Sequential(*trunk_blocks)
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
+            )
+            loss_per_token = F.cross_entropy(
+                logits.transpose(1, 2), x_0, reduction="none"
+            )
+            if reduce:
+                loss += (loss_per_token * mask_generate).sum(dim=1)
+            else:
+                loss += loss_per_token * mask_generate
 
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
+        record.append(loss)
 
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
+    loss = torch.cat(record, dim=0)
 
-    def forward(self, bs):
-        if torch.is_tensor(bs):
-            return self.forward(BracketedSequence(bs)).x
-        bs = self.embedding(bs)
-        bs = self.positional_encoding(bs)
-        bs = self.trunk(bs)
-        bs = self.readout(bs)
-        return bs
+    if log_probas:
+        return -loss
+    else:
+        return (-loss).exp()
 
 
 ######################################################################
 
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
-    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
-
 
 def ae_batches(
     quiz_machine,
@@ -527,272 +464,6 @@ def NTC_masked_cross_entropy(output, targets, mask):
     return (loss_per_token * mask).mean()
 
 
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-def deterministic(mask_generate):
-    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
-
-
-######################################################################
-
-torch.set_printoptions(
-    precision=None,
-    threshold=None,
-    edgeitems=None,
-    linewidth=500,
-    profile=None,
-    sci_mode=None,
-)
-
-N = quiz_machine.problem.nb_colors
-T = args.nb_diffusion_iterations + 1
-diffusion_M = torch.empty(T, N, N)
-diffusion_M[0] = torch.eye(N)
-
-# i >0 j>0
-# P(X'=0 | X=0) = 1-epsilon
-# P(X'=i | X=0) = epsilon/(N-1)
-# P(X'=0 | X=i) = delta
-# P(X'=X | X=i) = 1-epsilon-delta
-# P(X'=j | X=i) = epsilon/(N-2)
-
-diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
-diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
-diffusion_M[1, 0, 1:] = args.diffusion_epsilon / (N - 1) + args.diffusion_delta
-diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 1)
-
-for k in range(1, N):
-    diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
-
-# m = diffusion_M[1]
-
-# print(m)
-# print(m.sum(dim=0))
-# print(torch.linalg.matrix_power(m, 25))
-
-# exit(0)
-
-for t in range(2, T):
-    # diffusion_M[t] = diffusion_M[1] @ diffusion_M[t - 1]
-    diffusion_M[t] = torch.linalg.matrix_power(diffusion_M[1], t)
-
-# p = torch.full((N,), 1 / N)
-
-# for t in range(diffusion_M.size(0)):
-# print(diffusion_M[t] @ p)
-
-# print(diffusion_M[T-1])
-
-# exit(0)
-
-#
-# Given x_0 and t_0, t_1, ..., returns
-#
-#    x_{t_0}, ..., x_{t_K} ~ P(X_{t_0}, ..., X_{t_K} | X_0=x_0)
-#
-
-
-def sample_x_t_given_x_0(x_0, t):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-    r = torch.rand(x_0.size(), device=x_0.device)
-    proba_erased = 1 - (1 - args.diffusion_delta) ** t
-    mask_erased = (r <= proba_erased[:, None]).long()
-    x_t = (1 - mask_erased) * x_0 + mask_erased * noise
-
-    return x_t
-
-
-# This function returns a 2d tensor of same shape as low, full of
-# uniform random values in [0,1], such that, in every row, the values
-# corresponding to the True in low are all lesser than the values
-# corresponding to the False.
-
-
-def prioritized_rand(low):
-    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
-    k = torch.rand(low.size(), device=low.device) + low.long()
-    k = k.sort(dim=1).indices
-    y = x.new(x.size())
-    y.scatter_(dim=1, index=k, src=x)
-    return y
-
-
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
-    r = prioritized_rand(x_0 != x_t)
-
-    mask_changes = (r <= args.diffusion_delta).long()
-
-    x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
-
-    return x_t_minus_1
-
-
-######################################################################
-# Non-uniform transitions, to be fixed?
-
-
-def ___sample_x_t_given_x_0(x_0, t):
-    D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
-    mask = (x_0 < quiz_machine.problem.nb_colors).long()
-    probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
-    dist = torch.distributions.categorical.Categorical(probs=probas)
-    x_t = (1 - mask) * x_0 + mask * dist.sample()
-    return x_t
-
-
-def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
-    mask = (x_0 < quiz_machine.problem.nb_colors).long()
-
-    # i = x_0[n,s], j = x_t[n,s]
-    # probas[n,s,k] = M[1,x_t[n,s],k] M[t[n]-1,x_0[n,s],k] / M[t[n],x_0[n,s],x_t[n,s]]
-
-    # A[n,s,k] = M[1,x_t[n,s],k]
-    # B[n,s,k] = M[t[n]-1,x_0[n,s],k]
-    # C[n,s,k] = M[t[n],x_0[n,s],x_t[n,s]]
-    # probas = A * B / C
-
-    N, S, K = x_0.size(0), x_0.size(1), diffusion_M.size(1)
-
-    _1 = x_0.new_full((N, S, K), 1)
-    _t = x_0.new_full((N, S, K), t)
-    _k = torch.arange(K, device=x_0.device)[None, None, :].expand(N, S, K)
-    _x_t = (mask * x_t)[:, :, None].expand(N, S, K)
-    _x_0 = (mask * x_0)[:, :, None].expand(N, S, K)
-
-    M = diffusion_M.to(x_0.device)
-
-    probas = M[_1, _x_t, _k] * M[_t - 1, _x_0, _k] / M[_t, _x_0, _x_t]
-
-    dist = torch.distributions.categorical.Categorical(probs=probas)
-    x_t_minus_1 = (1 - mask) * x_0 + mask * dist.sample()
-
-    return x_t_minus_1
-
-
-######################################################################
-
-# This function gets a clean target x_0, and a mask indicating which
-# part to generate (conditionnaly to the others), and returns the
-# logits starting from a x_t|X_0=x_0 picked at random with t random
-
-
-def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
-    # We favor iterations near the clean signal
-
-    probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=x_0.device
-    )
-
-    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(x_0.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-    t = dist.sample() + 1
-
-    x_t = sample_x_t_given_x_0(x_0, t)
-
-    # Only the part to generate is degraded, the rest is a perfect
-    # noise-free conditionning
-
-    x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-    # We may inject noise to prevent high-complexity non-structure
-    # signal to be generated as a way of "increasing reasoning
-    # complexity"
-
-    if prompt_noise > 0:
-        mask_prompt_noise = (
-            torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
-        ).long()
-        noise = torch.randint(
-            quiz_machine.problem.nb_colors, x_t.size(), device=x_t.device
-        )
-        noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
-        x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
-
-    x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-
-    with torch.amp.autocast("cuda"):
-        logits_hat_x_0 = model(x_t_with_mask)
-
-    return logits_hat_x_0
-
-
-######################################################################
-
-
-def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-
-    single_iteration = deterministic(mask_generate)[:, None]
-
-    if mask_hints is None:
-        mask_start = mask_generate
-    else:
-        mask_start = mask_generate * (1 - mask_hints)
-
-    x_t = (1 - mask_start) * x_0 + mask_start * noise
-
-    changed = True
-
-    for it in range(nb_iterations_max):
-        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-        with torch.amp.autocast("cuda"):
-            logits = model(x_t_with_mask)
-        logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
-        dist = torch.distributions.categorical.Categorical(logits=logits)
-
-        hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-
-        hat_x_t_minus_1 = single_iteration * hat_x_0 + (
-            1 - single_iteration
-        ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
-
-        if hat_x_t_minus_1.equal(x_t):
-            break
-        else:
-            changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
-            x_t[changed] = hat_x_t_minus_1[changed]
-
-    return x_t
-
-
-######################################################################
-
-
-def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
-    record = []
-
-    for x_0 in input.split(args.batch_size):
-        loss = 0
-
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
-            loss_per_token = F.cross_entropy(
-                logits.transpose(1, 2), x_0, reduction="none"
-            )
-            if reduce:
-                loss += (loss_per_token * mask_generate).sum(dim=1)
-            else:
-                loss += loss_per_token * mask_generate
-
-        record.append(loss)
-
-    loss = torch.cat(record, dim=0)
-
-    if log_probas:
-        return -loss
-    else:
-        return (-loss).exp()
-
-
 ######################################################################
 
 
@@ -819,7 +490,11 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            logits = logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)
+            logits = diffuser.logits_hat_x_0_from_random_iteration(
+                model=model,
+                x_0=x_0,
+                mask_generate=mask_generate,
+            )
             loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
             acc_test_loss += loss.item() * x_0.size(0)
             nb_test_samples += x_0.size(0)
@@ -840,7 +515,9 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            result = ae_generate(model, (1 - mask_generate) * x_0, mask_generate)
+            result = diffuser.ae_generate(
+                model, (1 - mask_generate) * x_0, mask_generate
+            )
             correct = (result == x_0).min(dim=1).values.long()
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
                 :, :, 1
@@ -888,7 +565,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    scaler = torch.amp.GradScaler("cuda")
+    scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -904,21 +581,29 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        with torch.amp.autocast("cuda"):
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
+        # with torch.amp.autocast("cuda"):
+        logits = diffuser.logits_hat_x_0_from_random_iteration(
+            model=model,
+            x_0=x_0,
+            mask_generate=mask_generate,
+            prompt_noise=args.prompt_noise,
+        )
 
         loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
         acc_train_loss += loss.item() * x_0.size(0)
         nb_train_samples += x_0.size(0)
 
-        scaler.scale(loss).backward()
+        loss.backward()
 
         if nb_train_samples % args.batch_size == 0:
-            scaler.step(model.optimizer)
+            model.optimizer.step()
+
+        # scaler.scale(loss).backward()
+
+        # if nb_train_samples % args.batch_size == 0:
+        # scaler.step(model.optimizer)
 
-        scaler.update()
+        scaler.update()
 
     log_string(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"