From 82c200881c0ede5931d14ac4d36ef0b35b55be6f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 6 Sep 2024 14:55:45 +0200 Subject: [PATCH] Update. --- main.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 21c609c..150010f 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,9 @@ parser.add_argument("--nb_models", type=int, default=5) parser.add_argument("--nb_diffusion_iterations", type=int, default=25) -parser.add_argument("--diffusion_noise_proba", type=float, default=0.05) +parser.add_argument("--diffusion_delta", type=float, default=0.05) + +parser.add_argument("--diffusion_epsilon", type=float, default=0.01) parser.add_argument("--min_succeed_to_validate", type=int, default=2) @@ -284,6 +286,9 @@ else: assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 +###################################################################### + + # ------------------------------------------------------ alien_problem = grids.Grids( max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, @@ -774,6 +779,18 @@ def deterministic(mask_generate): ###################################################################### +N = quiz_machine.problem.nb_colors +T = 50 +MP = torch.empty(T, N, N) +MP[0] = torch.eye(N) +MP[1, :, 0] = args.diffusion_epsilon / (N - 1) +MP[1, 0, 0] = 1 - args.diffusion_epsilon +MP[1, :, 1:] = args.diffusion_delta / (N - 1) +for k in range(1, N): + MP[1, k, k] = 1 - args.diffusion_delta +for t in range(2, T): + MP[t] = MP[1] @ MP[t] + # # Given x_0 and t_0, t_1, ..., returns # @@ -781,20 +798,16 @@ def deterministic(mask_generate): # -def sample_x_t_given_x_0(x_0, steps_nb_iterations): +def sample_x_t_given_x_0(x_0, t): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) r = torch.rand(x_0.size(), device=x_0.device) - result = [] - - for n in steps_nb_iterations: - proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n - mask_erased = (r <= proba_erased[:, None]).long() - x = (1 - mask_erased) * x_0 + mask_erased * noise - result.append(x) + proba_erased = 1 - (1 - args.diffusion_delta) ** t + mask_erased = (r <= proba_erased[:, None]).long() + x_t = (1 - mask_erased) * x_0 + mask_erased * noise - return result + return x_t # This function returns a 2d tensor of same shape as low, full of @@ -815,7 +828,7 @@ def prioritized_rand(low): def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): r = prioritized_rand(x_0 != x_t) - mask_changes = (r <= args.diffusion_noise_proba).long() + mask_changes = (r <= args.diffusion_delta).long() x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0 @@ -840,9 +853,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise probs_iterations = probs_iterations.expand(x_0.size(0), -1) dist = torch.distributions.categorical.Categorical(probs=probs_iterations) - t_1 = dist.sample() + 1 + t = dist.sample() + 1 - (x_t,) = sample_x_t_given_x_0(x_0, (t_1,)) + x_t = sample_x_t_given_x_0(x_0, t) # Only the part to generate is degraded, the rest is a perfect # noise-free conditionning -- 2.39.5