Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 21:51:32 +0000 (23:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 21:51:32 +0000 (23:51 +0200)
main.py

diff --git a/main.py b/main.py
index 00f9175..b926f8e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -101,9 +101,9 @@ parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
 
-parser.add_argument("--diffusion_delta", type=float, default=0.1)
+parser.add_argument("--diffusion_delta", type=float, default=0.05)
 
-parser.add_argument("--diffusion_epsilon", type=float, default=0.01)
+parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
@@ -802,8 +802,9 @@ diffusion_M[0] = torch.eye(N)
 
 diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
 diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
-diffusion_M[1, 0, 1:] = args.diffusion_delta
-diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 2)
+diffusion_M[1, 0, 1:] = args.diffusion_epsilon / (N - 1) + args.diffusion_delta
+diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 1)
+
 for k in range(1, N):
     diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
 
@@ -835,7 +836,7 @@ for t in range(2, T):
 #
 
 
-def sample_x_t_given_x_0_(x_0, t):
+def sample_x_t_given_x_0(x_0, t):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
     r = torch.rand(x_0.size(), device=x_0.device)
     proba_erased = 1 - (1 - args.diffusion_delta) ** t
@@ -845,7 +846,7 @@ def sample_x_t_given_x_0_(x_0, t):
     return x_t
 
 
-def sample_x_t_given_x_0(x_0, t):
+def ___sample_x_t_given_x_0(x_0, t):
     D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
     mask = (x_0 < quiz_machine.problem.nb_colors).long()
     probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
@@ -869,7 +870,7 @@ def prioritized_rand(low):
     return y
 
 
-def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t):
+def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
     r = prioritized_rand(x_0 != x_t)
 
     mask_changes = (r <= args.diffusion_delta).long()
@@ -879,7 +880,7 @@ def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t):
     return x_t_minus_1
 
 
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
+def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
     mask = (x_0 < quiz_machine.problem.nb_colors).long()
 
     # i = x_0[n,s], j = x_t[n,s]
@@ -972,9 +973,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
 
 def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
+    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    x_t = (1 - mask_generate) * x_0  # + mask_generate * noise
+    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
 
     one_iteration_prediction = deterministic(mask_generate)[:, None]