parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
-parser.add_argument("--diffusion_delta", type=float, default=0.05)
+parser.add_argument("--diffusion_delta", type=float, default=0.1)
parser.add_argument("--diffusion_epsilon", type=float, default=0.01)
######################################################################
+# Widen torch's print formatting for debugging dumps of the transition
+# matrices below; linewidth=500 keeps each matrix row on one line, every
+# None argument leaves that option at its current/default value.
+torch.set_printoptions(
+    precision=None,
+    threshold=None,
+    edgeitems=None,
+    linewidth=500,
+    profile=None,
+    sci_mode=None,
+)
+
N = quiz_machine.problem.nb_colors
-T = 50
-MP = torch.empty(T, N, N)
-MP[0] = torch.eye(N)
-MP[1, :, 0] = args.diffusion_epsilon / (N - 1)
-MP[1, 0, 0] = 1 - args.diffusion_epsilon
-MP[1, :, 1:] = args.diffusion_delta / (N - 1)
+T = args.nb_diffusion_iterations + 1
+diffusion_M = torch.empty(T, N, N)
+diffusion_M[0] = torch.eye(N)
+
+# For i > 0, j > 0, j != i:
+# P(X'=0 | X=0) = 1-epsilon
+# P(X'=i | X=0) = epsilon/(N-1)
+# P(X'=0 | X=i) = delta
+# P(X'=i | X=i) = 1-epsilon-delta
+# P(X'=j | X=i) = epsilon/(N-2)
+
+diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
+diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
+diffusion_M[1, 0, 1:] = args.diffusion_delta
+diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 2)
for k in range(1, N):
- MP[1, k, k] = 1 - args.diffusion_delta
+ diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
+
+# m = diffusion_M[1]
+
+# print(m)
+# print(m.sum(dim=0))
+# print(torch.linalg.matrix_power(m, 25))
+
+# exit(0)
+
for t in range(2, T):
- MP[t] = MP[1] @ MP[t]
+ # diffusion_M[t] = diffusion_M[1] @ diffusion_M[t - 1]
+ diffusion_M[t] = torch.linalg.matrix_power(diffusion_M[1], t)
+
+# p = torch.full((N,), 1 / N)
+
+# for t in range(diffusion_M.size(0)):
+# print(diffusion_M[t] @ p)
+
+# print(diffusion_M[T-1])
+
+# exit(0)
#
# Given x_0 and t_0, t_1, ..., returns
#
-def sample_x_t_given_x_0(x_0, t):
+def sample_x_t_given_x_0_(x_0, t):
noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-
r = torch.rand(x_0.size(), device=x_0.device)
-
proba_erased = 1 - (1 - args.diffusion_delta) ** t
mask_erased = (r <= proba_erased[:, None]).long()
x_t = (1 - mask_erased) * x_0 + mask_erased * noise
return x_t
+def sample_x_t_given_x_0(x_0, t):
+    # Draw x_t ~ P(X_t | X_0 = x_0) in one shot from the precomputed
+    # transition matrices, where diffusion_M[s, a, b] = P(X_s = a | X_0 = b)
+    # as built above. x_0 is a 2d (batch, seq) tensor of token indices and
+    # t is a 1d tensor giving the diffusion step for each batch sample.
+    # After the permute, D[n, b, a] = diffusion_M[t[n], a, b], so
+    # D[n, x, :] is the distribution of X_t given X_0 = x.
+    D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
+    # 1 where the token is a color (subject to diffusion), 0 for special
+    # tokens (index >= nb_colors), which must pass through unchanged.
+    mask = (x_0 < quiz_machine.problem.nb_colors).long()
+    # probas[n, s, k] = P(X_t = k | X_0 = x_0[n, s]); special tokens are
+    # clamped to row 0 for the lookup, their sampled value is discarded below.
+    probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
+    dist = torch.distributions.categorical.Categorical(probs=probas)
+    # Keep special tokens from x_0, replace color tokens by the sample.
+    x_t = (1 - mask) * x_0 + mask * dist.sample()
+    return x_t
+
+
# This function returns a 2d tensor of same shape as low, full of
# uniform random values in [0,1], such that, in every row, the values
# corresponding to the True in low are all lesser than the values
return y
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
+def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t):
r = prioritized_rand(x_0 != x_t)
mask_changes = (r <= args.diffusion_delta).long()
return x_t_minus_1
+def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
+    # Sample x_{t-1} ~ P(X_{t-1} | X_t = x_t, X_0 = x_0), the one-step
+    # posterior of the diffusion chain. Convention (see the construction of
+    # diffusion_M above): M[s, a, b] = P(X_s = a | X_0 = b), i.e. rows index
+    # the next state, columns the current one. t is a plain int, >= 1.
+    mask = (x_0 < quiz_machine.problem.nb_colors).long()
+
+    # By Bayes and the Markov property, for i = x_0[n,s], j = x_t[n,s]:
+    # probas[n,s,k] = M[1,j,k] M[t-1,k,i] / M[t,j,i]
+
+    # A[n,s,k] = M[1,x_t[n,s],k]         = P(X_t = j | X_{t-1} = k)
+    # B[n,s,k] = M[t-1,k,x_0[n,s]]       = P(X_{t-1} = k | X_0 = i)
+    # C[n,s,k] = M[t,x_t[n,s],x_0[n,s]]  = P(X_t = j | X_0 = i), normalizer
+    # probas = A * B / C
+
+    N, S, K = x_0.size(0), x_0.size(1), diffusion_M.size(1)
+
+    _1 = x_0.new_full((N, S, K), 1)
+    _t = x_0.new_full((N, S, K), t)
+    _k = torch.arange(K, device=x_0.device)[None, None, :].expand(N, S, K)
+    # Special (non-color) tokens are clamped to index 0 for the lookups and
+    # restored unchanged below.
+    _x_t = (mask * x_t)[:, :, None].expand(N, S, K)
+    _x_0 = (mask * x_0)[:, :, None].expand(N, S, K)
+
+    M = diffusion_M.to(x_0.device)
+
+    # BUGFIX: B and C previously used transposed indices, M[t-1,_x_0,_k] and
+    # M[t,_x_0,_x_t], i.e. P(X_{t-1}=x_0 | X_0=k) instead of
+    # P(X_{t-1}=k | X_0=x_0). That only coincides with the posterior when M
+    # is symmetric, which it is not (delta != epsilon/(N-1) in general); the
+    # Categorical's re-normalization hid the flipped denominator.
+    probas = M[_1, _x_t, _k] * M[_t - 1, _k, _x_0] / M[_t, _x_t, _x_0]
+
+    dist = torch.distributions.categorical.Categorical(probs=probas)
+    x_t_minus_1 = (1 - mask) * x_0 + mask * dist.sample()
+
+    return x_t_minus_1
+
+
######################################################################
# This function gets a clean target x_0, and a mask indicating which
x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # filename = f"debug.png"
+
+ # quiz_machine.problem.save_quizzes_as_image(
+ # args.result_dir,
+ # filename,
+ # quizzes=x_t,
+ # )
+
+    # log_string(f"wrote {filename}")
+ # exit(0)
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
# We may inject noise to prevent high-complexity non-structure
# signal to be generated as a way of "increasing reasoning
# complexity"
def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
- noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
+ # noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
- x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+ x_t = (1 - mask_generate) * x_0 # + mask_generate * noise
one_iteration_prediction = deterministic(mask_generate)[:, None]
for it in range(nb_iterations_max):
x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
logits = model(x_t_with_mask)
+ logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
dist = torch.distributions.categorical.Categorical(logits=logits)
hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
1 - one_iteration_prediction
- ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
+ ) * sample_x_t_minus_1_given_x_0_x_t(
+ hat_x_0, x_t, max(1, args.nb_diffusion_iterations - it)
+ )
if hat_x_t_minus_1.equal(x_t):
# log_string(f"exit after {it+1} iterations")
# Save some images
for f, record in [("prediction", record_d), ("generation", record_nd)]:
- filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-
result, predicted_parts, correct_parts = bag_to_tensors(record)
+ filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
filename,
# exit(0)
- # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device)
+ # one_ae_epoch(models[0], quiz_machine, n_epoch, None, main_device)
# exit(0)
log_string(f"{time_train=} {time_c_quizzes=}")