# Diffusion hyper-parameters. delta is the per-step probability that a
# token is re-noised; epsilon is the extra mass involving color 0 in the
# one-step transition matrix (see the MP construction below).
parser.add_argument("--nb_diffusion_iterations", type=int, default=25)

parser.add_argument("--diffusion_delta", type=float, default=0.05)

parser.add_argument("--diffusion_epsilon", type=float, default=0.01)

parser.add_argument("--min_succeed_to_validate", type=int, default=2)

# Batches must tile the train and test sets exactly.
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
######################################################################


# ------------------------------------------------------
alien_problem = grids.Grids(
max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
######################################################################

# Forward-diffusion transition matrices over the N color tokens:
# MP[t][i][j] is the probability that color i has become color j after t
# one-step transitions (MP[t] = MP[1] ** t).
N = quiz_machine.problem.nb_colors
T = 50
MP = torch.empty(T, N, N)

# Zero steps: identity.
MP[0] = torch.eye(N)

# One step: column 0 gets epsilon / (N - 1) from every color, color 0
# keeps 1 - epsilon, and each color k >= 1 spreads delta / (N - 1) over
# the others and keeps 1 - delta.
# NOTE(review): rows sum exactly to 1 only when epsilon == delta — confirm
# this asymmetry is intentional.
MP[1, :, 0] = args.diffusion_epsilon / (N - 1)
MP[1, 0, 0] = 1 - args.diffusion_epsilon
MP[1, :, 1:] = args.diffusion_delta / (N - 1)
for k in range(1, N):
    MP[1, k, k] = 1 - args.diffusion_delta

# t steps: one more application of the one-step matrix. The original code
# computed MP[1] @ MP[t], which reads the *uninitialized* MP[t] coming
# from torch.empty; the recurrence must use the previous step MP[t - 1].
for t in range(2, T):
    MP[t] = MP[1] @ MP[t - 1]
# Given a batch x_0 of clean token grids and a 1d tensor t holding a
# per-sample number of diffusion steps, returns x_t where every token of
# x_0 has been replaced by a uniformly random color with probability
# 1 - (1 - delta) ** t, i.e. the probability that it was re-noised at
# least once in t independent steps.
def sample_x_t_given_x_0(x_0, t):
    # One candidate noise value per position; used only where erased.
    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
    r = torch.rand(x_0.size(), device=x_0.device)
    # t is per-sample, so broadcast it over the token dimension.
    proba_erased = 1 - (1 - args.diffusion_delta) ** t
    mask_erased = (r <= proba_erased[:, None]).long()
    x_t = (1 - mask_erased) * x_0 + mask_erased * noise
    return x_t
# Given the clean grids x_0 and a noised version x_t, samples x_{t-1} by
# reverting each still-corrupted token (x_0 != x_t) back to its clean
# value with probability delta. Candidate positions are ranked through
# prioritized_rand — TODO confirm its ordering semantics against its
# definition elsewhere in the file.
def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
    r = prioritized_rand(x_0 != x_t)
    mask_changes = (r <= args.diffusion_delta).long()
    x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
    # The visible original computed x_t_minus_1 but never returned it.
    return x_t_minus_1
# NOTE(review): fragment of an enclosing routine whose def lies outside
# this hunk — probs_iterations and x_0 come from that enclosing scope.
# Draw a per-sample step count t >= 1 from the categorical distribution
# over iteration counts, then degrade x_0 accordingly.
probs_iterations = probs_iterations.expand(x_0.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
t = dist.sample() + 1
x_t = sample_x_t_given_x_0(x_0, t)
# Only the part to generate is degraded, the rest is a perfect
# noise-free conditioning