From 95a2bb0ee9928b557d60ff132c21f84f00947c5b Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 15 Sep 2024 11:30:14 +0200
Subject: [PATCH] Update.

---
 diffusion.py | 137 +++++++++++++++
 main.py      | 457 ++++++++-------------------------------------
 2 files changed, 208 insertions(+), 386 deletions(-)
 create mode 100755 diffusion.py

diff --git a/diffusion.py b/diffusion.py
new file mode 100755
index 0000000..98d8d0a
--- /dev/null
+++ b/diffusion.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
+def NTC_channel_cat(*x):
+    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
+class Diffuser:
+    def __init__(self, mu_T_sampler, nb_iterations, proba_corruption):
+        self.mu_T_sampler = mu_T_sampler
+        self.nb_iterations = nb_iterations
+        self.proba_corruption = proba_corruption
+
+    def sample_x_t_given_x_0(self, x_0, t):
+        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+        r = torch.rand(x_0.size(), device=x_0.device)
+        proba_erased = 1 - (1 - self.proba_corruption) ** t
+        mask_erased = (r <= proba_erased[:, None]).long()
+        x_t = (1 - mask_erased) * x_0 + mask_erased * noise
+
+        return x_t
+
+    # This function returns a 2d tensor of the same shape as low, full of
+    # uniform random values in [0,1], such that, in every row, the values
+    # corresponding to the True in low are all less than the values
+    # corresponding to the False.
+
+    def prioritized_rand(self, low):
+        x = (
+            torch.rand(low.size(), device=low.device)
+            .sort(dim=1, descending=True)
+            .values
+        )
+        k = torch.rand(low.size(), device=low.device) + low.long()
+        k = k.sort(dim=1).indices
+        y = x.new(x.size())
+        y.scatter_(dim=1, index=k, src=x)
+        return y
+
+    def sample_x_t_minus_1_given_x_0_x_t(self, x_0, x_t):
+        r = self.prioritized_rand(x_0 != x_t)
+        mask_changes = (r <= self.proba_corruption).long()
+        x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
+        return x_t_minus_1
+
+    ######################################################################
+
+    # This function gets a clean target x_0 and a mask indicating which
+    # part to generate (conditionally on the others), and returns the
+    # logits computed from an x_t ~ X_t|X_0=x_0 with t picked at random
+
+    def logits_hat_x_0_from_random_iteration(
+        self, model, x_0, mask_generate, prompt_noise=0.0
+    ):
+        # We favor iterations near the clean signal
+
+        probs_iterations = 0.1 ** torch.linspace(
+            0, 1, self.nb_iterations, device=x_0.device
+        )
+
+        probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+        probs_iterations = probs_iterations.expand(x_0.size(0), -1)
+        dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+
+        t = dist.sample() + 1
+
+        x_t = self.sample_x_t_given_x_0(x_0, t)
+
+        # Only the part to generate is degraded, the rest is a perfect
+        # noise-free conditioning
+
+        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+
+        # We may inject noise to prevent high-complexity non-structured
+        # signals from being generated, as a way of "increasing reasoning
+        # complexity"
+
+        if prompt_noise > 0:
+            mask_prompt_noise = (
+                torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
+            ).long()
+            noise = self.mu_T_sampler(x_t.size(), device=x_t.device)
+            noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
+            x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
+
+        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+
+        # with torch.amp.autocast("cuda"):
+        logits_hat_x_0 = model(x_t_with_mask)
+
+        return logits_hat_x_0
+
+    ######################################################################
+
+    def ae_generate(self, model, x_0, mask_generate, mask_hints=None):
+        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+
+        single_iteration = (
+            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+        ).long()[:, None]
+
+        if mask_hints is None:
+            mask_start = mask_generate
+        else:
+            mask_start = mask_generate * (1 - mask_hints)
+
+        x_t = (1 - mask_start) * x_0 + mask_start * noise
+
+        changed = True
+
+        for it in range(self.nb_iterations):
+            x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+            # with torch.amp.autocast("cuda"):
+            logits = model(x_t_with_mask)
+            # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+
+            hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
+
+            hat_x_t_minus_1 = single_iteration * hat_x_0 + (
+                1 - single_iteration
+            ) * self.sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
+
+            if hat_x_t_minus_1.equal(x_t):
+                break
+            else:
+                changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
+                x_t[changed] = hat_x_t_minus_1[changed]
+
+        return x_t
diff --git a/main.py b/main.py
index 01ce963..534bab9 100755
--- a/main.py
+++ b/main.py
@@ -23,9 +23,11 @@ from quiz_machine import one_batch_masked_inplace_autoregression
 
 import threading, subprocess
 
-import torch.multiprocessing as mp
+# import torch.multiprocessing as mp
 
-torch.set_float32_matmul_precision("high")
+# torch.set_float32_matmul_precision("high")
+
+import diffusion
 
 ######################################################################
 
@@ -107,9 +109,7 @@ parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
 
-parser.add_argument("--diffusion_delta", type=float, default=0.05)
-
-parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
+parser.add_argument("--proba_diffusion_corruption", type=float, default=0.05)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
@@ -322,6 +322,15 @@ quiz_machine = quiz_machine.QuizMachine(
     device=main_device,
 )
 
+
+def mu_T_sampler(shape, device="cpu"):
+    return torch.randint(quiz_machine.problem.nb_colors, shape, device=device)
+
+
+diffuser = diffusion.Diffuser(
+    mu_T_sampler, args.nb_diffusion_iterations, args.proba_diffusion_corruption
+)
+
 ######################################################################
 
 log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
@@ -363,125 +372,53 @@ def optimizer_to(optim, device):
 
 ######################################################################
 
+# quad_order, quad_generate, quad_noise, quad_loss
 
-from mygpt import (
-    CachedWithResidual,
-    CacheWrapper,
-    CachedVaswaniPositionalEncoding,
-    QKVAttention,
-    BracketedSequence,
-)
-
-
-class MultiEmbedding(nn.Module):
-    def __init__(self, nb_values, dim):
-        super().__init__()
-        self.embeddings = nn.ModuleList([nn.Embedding(n, dim) for n in nb_values])
-
-    def forward(self, x):
-        y = 0
-        for f, z in zip(self.embeddings, x.split(1, dim=2)):
-            y = y + f(z[:, :, 0])
-        return y
-
-
-def attention_block(dim_model, dim_keys, nb_heads, dropout):
-    return CachedWithResidual(
-        CacheWrapper(
-            nn.LayerNorm((dim_model,)),
-        ),
-        QKVAttention(
-            dim_in=dim_model,
-            dim_qk=dim_keys,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            attention_dropout=dropout,
-        ),
-    )
-
-
-def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
-    return CachedWithResidual(
-        CacheWrapper(
-            nn.LayerNorm((dim_model,)),
-            nn.Linear(in_features=dim_model, out_features=dim_hidden),
-            nn.ReLU(),
-            nn.Linear(in_features=dim_hidden, out_features=dim_model),
-            nn.Dropout(dropout),
-        ),
-    )
-
-
-class MyAttentionAE(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        dropout=0.0,
-        len_max=1024,
-    ):
-        super().__init__()
+data_structures = [
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
 
-        assert dim_model % nb_heads == 0
-
-        self.embedding = CacheWrapper(
-            nn.Sequential(
-                MultiEmbedding((vocabulary_size, 2), dim_model),
-                nn.Dropout(dropout),
-            ),
-        )
+######################################################################
 
-        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
-        self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5)
 
-        trunk_blocks = []
+def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
+    record = []
 
-        for b in range(nb_blocks):
-            trunk_blocks += [
-                attention_block(dim_model, dim_keys, nb_heads, dropout),
-                ffw_block(dim_model, dim_hidden, nb_heads, dropout),
-            ]
+    for x_0 in input.split(args.batch_size):
+        loss = 0
 
-        self.trunk = nn.Sequential(*trunk_blocks)
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            logits = diffuser.logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
+            )
+            loss_per_token = F.cross_entropy(
+                logits.transpose(1, 2), x_0, reduction="none"
+            )
+            if reduce:
+                loss += (loss_per_token * mask_generate).sum(dim=1)
+            else:
+                loss += loss_per_token * mask_generate
 
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
+        record.append(loss)
 
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
+    loss = torch.cat(record, dim=0)
 
-    def forward(self, bs):
-        if torch.is_tensor(bs):
-            return self.forward(BracketedSequence(bs)).x
-        bs = self.embedding(bs)
-        bs = self.positional_encoding(bs)
-        bs = self.trunk(bs)
-        bs = self.readout(bs)
-        return bs
+    if log_probas:
+        return -loss
+    else:
+        return (-loss).exp()
 
 
 ######################################################################
 
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
-    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
 
 def ae_batches(
     quiz_machine,
@@ -527,272 +464,6 @@ def NTC_masked_cross_entropy(output, targets, mask):
     return (loss_per_token * mask).mean()
 
 
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-def deterministic(mask_generate):
-    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
-
-
-######################################################################
-
-torch.set_printoptions(
-    precision=None,
-    threshold=None,
-    edgeitems=None,
-    linewidth=500,
-    profile=None,
-    sci_mode=None,
-)
-
-N = quiz_machine.problem.nb_colors
-T = args.nb_diffusion_iterations + 1
-diffusion_M = torch.empty(T, N, N)
-diffusion_M[0] = torch.eye(N)
-
-# i >0 j>0
-# P(X'=0 | X=0) = 1-epsilon
-# P(X'=i | X=0) = epsilon/(N-1)
-# P(X'=0 | X=i) = delta
-# P(X'=X | X=i) = 1-epsilon-delta
-# P(X'=j | X=i) = epsilon/(N-2)
-
-diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
-diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
-diffusion_M[1, 0, 1:] = args.diffusion_epsilon / (N - 1) + args.diffusion_delta
-diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 1)
-
-for k in range(1, N):
-    diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
-
-# m = diffusion_M[1]
-
-# print(m)
-# print(m.sum(dim=0))
-# print(torch.linalg.matrix_power(m, 25))
-
-# exit(0)
-
-for t in range(2, T):
-    # diffusion_M[t] = diffusion_M[1] @ diffusion_M[t - 1]
-    diffusion_M[t] = torch.linalg.matrix_power(diffusion_M[1], t)
-
-# p = torch.full((N,), 1 / N)
-
-# for t in range(diffusion_M.size(0)):
-#     print(diffusion_M[t] @ p)
-
-# print(diffusion_M[T-1])
-
-# exit(0)
-
-#
-# Given x_0 and t_0, t_1, ..., returns
-#
-# x_{t_0}, ..., x_{t_K} ~ P(X_{t_0}, ..., X_{t_K} | X_0=x_0)
-#
-
-
-def sample_x_t_given_x_0(x_0, t):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-    r = torch.rand(x_0.size(), device=x_0.device)
-    proba_erased = 1 - (1 - args.diffusion_delta) ** t
-    mask_erased = (r <= proba_erased[:, None]).long()
-    x_t = (1 - mask_erased) * x_0 + mask_erased * noise
-
-    return x_t
-
-
-# This function returns a 2d tensor of same shape as low, full of
-# uniform random values in [0,1], such that, in every row, the values
-# corresponding to the True in low are all lesser than the values
-# corresponding to the False.
-
-
-def prioritized_rand(low):
-    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
-    k = torch.rand(low.size(), device=low.device) + low.long()
-    k = k.sort(dim=1).indices
-    y = x.new(x.size())
-    y.scatter_(dim=1, index=k, src=x)
-    return y
-
-
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
-    r = prioritized_rand(x_0 != x_t)
-
-    mask_changes = (r <= args.diffusion_delta).long()
-
-    x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
-
-    return x_t_minus_1
-
-
-######################################################################
-# Non-uniform transitions, to be fixed?
-
-
-def ___sample_x_t_given_x_0(x_0, t):
-    D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
-    mask = (x_0 < quiz_machine.problem.nb_colors).long()
-    probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
-    dist = torch.distributions.categorical.Categorical(probs=probas)
-    x_t = (1 - mask) * x_0 + mask * dist.sample()
-    return x_t
-
-
-def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
-    mask = (x_0 < quiz_machine.problem.nb_colors).long()
-
-    # i = x_0[n,s], j = x_t[n,s]
-    # probas[n,s,k] = M[1,x_t[n,s],k] M[t[n]-1,x_0[n,s],k] / M[t[n],x_0[n,s],x_t[n,s]]
-
-    # A[n,s,k] = M[1,x_t[n,s],k]
-    # B[n,s,k] = M[t[n]-1,x_0[n,s],k]
-    # C[n,s,k] = M[t[n],x_0[n,s],x_t[n,s]]
-    # probas = A * B / C
-
-    N, S, K = x_0.size(0), x_0.size(1), diffusion_M.size(1)
-
-    _1 = x_0.new_full((N, S, K), 1)
-    _t = x_0.new_full((N, S, K), t)
-    _k = torch.arange(K, device=x_0.device)[None, None, :].expand(N, S, K)
-    _x_t = (mask * x_t)[:, :, None].expand(N, S, K)
-    _x_0 = (mask * x_0)[:, :, None].expand(N, S, K)
-
-    M = diffusion_M.to(x_0.device)
-
-    probas = M[_1, _x_t, _k] * M[_t - 1, _x_0, _k] / M[_t, _x_0, _x_t]
-
-    dist = torch.distributions.categorical.Categorical(probs=probas)
-    x_t_minus_1 = (1 - mask) * x_0 + mask * dist.sample()
-
-    return x_t_minus_1
-
-
-######################################################################
-
-# This function gets a clean target x_0, and a mask indicating which
-# part to generate (conditionnaly to the others), and returns the
-# logits starting from a x_t|X_0=x_0 picked at random with t random
-
-
-def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
-    # We favor iterations near the clean signal
-
-    probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=x_0.device
-    )
-
-    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(x_0.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-    t = dist.sample() + 1
-
-    x_t = sample_x_t_given_x_0(x_0, t)
-
-    # Only the part to generate is degraded, the rest is a perfect
-    # noise-free conditionning
-
-    x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-    # We may inject noise to prevent high-complexity non-structure
-    # signal to be generated as a way of "increasing reasoning
-    # complexity"
-
-    if prompt_noise > 0:
-        mask_prompt_noise = (
-            torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
-        ).long()
-        noise = torch.randint(
-            quiz_machine.problem.nb_colors, x_t.size(), device=x_t.device
-        )
-        noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
-        x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
-
-    x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-
-    with torch.amp.autocast("cuda"):
-        logits_hat_x_0 = model(x_t_with_mask)
-
-    return logits_hat_x_0
-
-
-######################################################################
-
-
-def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-
-    single_iteration = deterministic(mask_generate)[:, None]
-
-    if mask_hints is None:
-        mask_start = mask_generate
-    else:
-        mask_start = mask_generate * (1 - mask_hints)
-
-    x_t = (1 - mask_start) * x_0 + mask_start * noise
-
-    changed = True
-
-    for it in range(nb_iterations_max):
-        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-        with torch.amp.autocast("cuda"):
-            logits = model(x_t_with_mask)
-        logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
-        dist = torch.distributions.categorical.Categorical(logits=logits)
-
-        hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-
-        hat_x_t_minus_1 = single_iteration * hat_x_0 + (
-            1 - single_iteration
-        ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
-
-        if hat_x_t_minus_1.equal(x_t):
-            break
-        else:
-            changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
-            x_t[changed] = hat_x_t_minus_1[changed]
-
-    return x_t
-
-
-######################################################################
-
-
-def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
-    record = []
-
-    for x_0 in input.split(args.batch_size):
-        loss = 0
-
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
-            loss_per_token = F.cross_entropy(
-                logits.transpose(1, 2), x_0, reduction="none"
-            )
-            if reduce:
-                loss += (loss_per_token * mask_generate).sum(dim=1)
-            else:
-                loss += loss_per_token * mask_generate
-
-        record.append(loss)
-
-    loss = torch.cat(record, dim=0)
-
-    if log_probas:
-        return -loss
-    else:
-        return (-loss).exp()
-
-
 ######################################################################
 
 
@@ -819,7 +490,11 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            logits = logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)
+            logits = diffuser.logits_hat_x_0_from_random_iteration(
+                model=model,
+                x_0=x_0,
+                mask_generate=mask_generate,
+            )
             loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
             acc_test_loss += loss.item() * x_0.size(0)
             nb_test_samples += x_0.size(0)
@@ -840,7 +515,9 @@ def run_ae_test(
             c_quizzes=c_quizzes,
            desc="test",
         ):
-            result = ae_generate(model, (1 - mask_generate) * x_0, mask_generate)
+            result = diffuser.ae_generate(
+                model, (1 - mask_generate) * x_0, mask_generate
+            )
             correct = (result == x_0).min(dim=1).values.long()
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
                 :, :, 1
             ]
@@ -888,7 +565,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    scaler = torch.amp.GradScaler("cuda")
+    # scaler = torch.amp.GradScaler("cuda")
 
     for x_0, mask_generate in ae_batches(
         quiz_machine,
@@ -904,21 +581,29 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        with torch.amp.autocast("cuda"):
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
+        # with torch.amp.autocast("cuda"):
+        logits = diffuser.logits_hat_x_0_from_random_iteration(
+            model=model,
+            x_0=x_0,
+            mask_generate=mask_generate,
+            prompt_noise=args.prompt_noise,
+        )
 
         loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
         acc_train_loss += loss.item() * x_0.size(0)
        nb_train_samples += x_0.size(0)
 
-        scaler.scale(loss).backward()
+        loss.backward()
 
        if nb_train_samples % args.batch_size == 0:
-            scaler.step(model.optimizer)
+            model.optimizer.step()
+
+        # scaler.scale(loss).backward()
+
+        # if nb_train_samples % args.batch_size == 0:
+        #     scaler.step(model.optimizer)
 
-            scaler.update()
+        # scaler.update()
 
     log_string(
        f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
-- 
2.39.5
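
Usage sketch (illustrative only, not part of the commit above). It shows how
the new diffusion.Diffuser introduced by this patch is meant to be driven:
build it from a mu_T_sampler over the token alphabet, get training logits
with logits_hat_x_0_from_random_iteration, and run iterative denoising with
ae_generate. The toy model, alphabet size and sequence length below are
invented for illustration; only the Diffuser API and the (N, T, 2) input
convention produced by NTC_channel_cat come from the patch itself.

    import torch
    from torch import nn

    import diffusion

    nb_colors, seq_len, batch_size = 12, 100, 4

    # Same role as mu_T_sampler in main.py: uniform noise over the alphabet.
    def mu_T_sampler(shape, device="cpu"):
        return torch.randint(nb_colors, shape, device=device)

    diffuser = diffusion.Diffuser(
        mu_T_sampler, nb_iterations=25, proba_corruption=0.05
    )

    # Hypothetical stand-in for the real model: it must accept the (N, T, 2)
    # long tensor built by NTC_channel_cat(x_t, mask_generate) and return
    # logits of shape (N, T, nb_colors).
    class ToyModel(nn.Module):
        def __init__(self, dim=64):
            super().__init__()
            self.embedding = nn.Embedding(nb_colors + 2, dim)
            self.readout = nn.Linear(dim, nb_colors)

        def forward(self, x_with_mask):
            x_t, mask = x_with_mask[:, :, 0], x_with_mask[:, :, 1]
            h = self.embedding(x_t) + self.embedding(nb_colors + mask)
            return self.readout(h)

    model = ToyModel()

    x_0 = torch.randint(nb_colors, (batch_size, seq_len))
    mask_generate = torch.zeros_like(x_0)
    mask_generate[:, seq_len // 2 :] = 1  # generate the second half only

    # Training-style call: logits for hat(x_0) from a randomly corrupted x_t,
    # with t drawn so that iterations near the clean signal are favored.
    logits = diffuser.logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)

    # Inference-style call: iterative denoising of the masked part.
    result = diffuser.ae_generate(model, (1 - mask_generate) * x_0, mask_generate)

For reference, the corruption schedule is the one implemented in
sample_x_t_given_x_0: each token is erased independently with probability
1 - (1 - proba_corruption)**t, and logits_hat_x_0_from_random_iteration then
keeps that corruption only on the part selected by mask_generate.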