--- /dev/null
+#!/usr/bin/env python
+
+import math
+
+import torch, torchvision
+
+from torch import nn
+from torch.nn import functional as F
+
+
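+# Stacks its arguments along a new third ("channel") dimension, each one
+# broadcast to the shape of the first, yielding an N x T x C tensor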
+def NTC_channel_cat(*x):
+ return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
+class Diffuser:
+ def __init__(self, mu_T_sampler, nb_iterations, proba_corruption):
+ self.mu_T_sampler = mu_T_sampler
+ self.nb_iterations = nb_iterations
+ self.proba_corruption = proba_corruption
+
+ def sample_x_t_given_x_0(self, x_0, t):
+ noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+ r = torch.rand(x_0.size(), device=x_0.device)
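+        # A token survives one corruption step with probability
+        # 1 - proba_corruption, hence is corrupted at least once after t
+        # steps with probability 1 - (1 - proba_corruption) ** t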
+ proba_erased = 1 - (1 - self.proba_corruption) ** t
+ mask_erased = (r <= proba_erased[:, None]).long()
+ x_t = (1 - mask_erased) * x_0 + mask_erased * noise
+
+ return x_t
+
+    # This function returns a 2d tensor of the same shape as low, filled
+    # with uniform random values in [0,1], such that, in every row, the
+    # values at the positions where low is True are all less than the
+    # values at the positions where it is False.
+
+ def prioritized_rand(self, low):
+ x = (
+ torch.rand(low.size(), device=low.device)
+ .sort(dim=1, descending=True)
+ .values
+ )
+ k = torch.rand(low.size(), device=low.device) + low.long()
+ k = k.sort(dim=1).indices
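+        # k orders the columns with the False positions of low first and
+        # the True positions last, so scattering the descending values x
+        # gives the True positions the smallest values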
+ y = x.new(x.size())
+ y.scatter_(dim=1, index=k, src=x)
+ return y
+
+ def sample_x_t_minus_1_given_x_0_x_t(self, x_0, x_t):
+ r = self.prioritized_rand(x_0 != x_t)
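+        # r is smallest where x_t differs from x_0, so those tokens are
+        # the first to be reverted to x_0 when thresholding at
+        # proba_corruption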
+ mask_changes = (r <= self.proba_corruption).long()
+ x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
+ return x_t_minus_1
+
+ ######################################################################
+
+    # This function takes a clean target x_0 and a mask indicating which
+    # part to generate (conditionally on the others), and returns the
+    # logits for the estimate of x_0, computed from an x_t ~ P(X_t | X_0=x_0)
+    # with t picked at random
+
+ def logits_hat_x_0_from_random_iteration(
+ self, model, x_0, mask_generate, prompt_noise=0.0
+ ):
+ # We favor iterations near the clean signal
+
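+        # P(t) decays geometrically from 1 down to 0.1 over the
+        # nb_iterations time steps (before normalization)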
+ probs_iterations = 0.1 ** torch.linspace(
+ 0, 1, self.nb_iterations, device=x_0.device
+ )
+
+ probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+ probs_iterations = probs_iterations.expand(x_0.size(0), -1)
+ dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+
+ t = dist.sample() + 1
+
+ x_t = self.sample_x_t_given_x_0(x_0, t)
+
+        # Only the part to generate is degraded; the rest is a perfect,
+        # noise-free conditioning
+
+ x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+
+        # We may inject noise to prevent a high-complexity, non-structured
+        # signal from being generated, as a way of "increasing reasoning
+        # complexity"
+
+ if prompt_noise > 0:
+ mask_prompt_noise = (
+ torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
+ ).long()
+ noise = self.mu_T_sampler(x_t.size(), device=x_t.device)
+ noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
+ x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
+
+ x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+
+ # with torch.amp.autocast("cuda"):
+ logits_hat_x_0 = model(x_t_with_mask)
+
+ return logits_hat_x_0
+
+ ######################################################################
+
+ def ae_generate(self, model, x_0, mask_generate, mask_hints=None):
+ noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+
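+        # When less than half of the tokens have to be generated, the
+        # model's estimate hat_x_0 is used directly as the next x_t,
+        # instead of a partial reverse-diffusion step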
+ single_iteration = (
+ mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+ ).long()[:, None]
+
+ if mask_hints is None:
+ mask_start = mask_generate
+ else:
+ mask_start = mask_generate * (1 - mask_hints)
+
+ x_t = (1 - mask_start) * x_0 + mask_start * noise
+
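+        # Keeps track of which samples of the batch are still being
+        # updated; the others are frozen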
+ changed = True
+
+ for it in range(self.nb_iterations):
+ x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+ # with torch.amp.autocast("cuda"):
+ logits = model(x_t_with_mask)
+ # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+
+ hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
+
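+            # Deterministic case: take the estimate directly; otherwise
+            # perform one reverse-diffusion step toward hat_x_0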
+ hat_x_t_minus_1 = single_iteration * hat_x_0 + (
+ 1 - single_iteration
+ ) * self.sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
+
+ if hat_x_t_minus_1.equal(x_t):
+ break
+ else:
+ changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
+ x_t[changed] = hat_x_t_minus_1[changed]
+
+ return x_t
import threading, subprocess
-import torch.multiprocessing as mp
+# import torch.multiprocessing as mp
-torch.set_float32_matmul_precision("high")
+# torch.set_float32_matmul_precision("high")
+
+import diffusion
######################################################################
parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
-parser.add_argument("--diffusion_delta", type=float, default=0.05)
-
-parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
+parser.add_argument("--proba_diffusion_corruption", type=float, default=0.05)
parser.add_argument("--min_succeed_to_validate", type=int, default=2)
device=main_device,
)
+
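+# Terminal distribution mu_T of the diffusion: uniform over the problem's
+# colors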
+def mu_T_sampler(shape, device="cpu"):
+ return torch.randint(quiz_machine.problem.nb_colors, shape, device=device)
+
+
+diffuser = diffusion.Diffuser(
+ mu_T_sampler, args.nb_diffusion_iterations, args.proba_diffusion_corruption
+)
+
######################################################################
log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
######################################################################
+# quad_order, quad_generate, quad_noise, quad_loss
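+# Each entry of data_structures gives the order of the four grids and, for
+# each grid, whether it is generated, receives prompt noise, and is
+# included in the loss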
-from mygpt import (
- CachedWithResidual,
- CacheWrapper,
- CachedVaswaniPositionalEncoding,
- QKVAttention,
- BracketedSequence,
-)
-
-
-class MultiEmbedding(nn.Module):
- def __init__(self, nb_values, dim):
- super().__init__()
- self.embeddings = nn.ModuleList([nn.Embedding(n, dim) for n in nb_values])
-
- def forward(self, x):
- y = 0
- for f, z in zip(self.embeddings, x.split(1, dim=2)):
- y = y + f(z[:, :, 0])
- return y
-
-
-def attention_block(dim_model, dim_keys, nb_heads, dropout):
- return CachedWithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention_dropout=dropout,
- ),
- )
-
-
-def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
- return CachedWithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- )
-
-
-class MyAttentionAE(nn.Module):
- def __init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- dropout=0.0,
- len_max=1024,
- ):
- super().__init__()
+data_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+ (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
- assert dim_model % nb_heads == 0
- self.embedding = CacheWrapper(
- nn.Sequential(
- MultiEmbedding((vocabulary_size, 2), dim_model),
- nn.Dropout(dropout),
- ),
- )
+######################################################################
- # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
- self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5)
- trunk_blocks = []
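+# Returns, for each quiz of input, exp(-L), or -L if log_probas is set,
+# where L sums, over the four grids, the masked cross-entropy of
+# reconstructing that grid from the three others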
+def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
+ record = []
- for b in range(nb_blocks):
- trunk_blocks += [
- attention_block(dim_model, dim_keys, nb_heads, dropout),
- ffw_block(dim_model, dim_hidden, nb_heads, dropout),
- ]
+ for x_0 in input.split(args.batch_size):
+ loss = 0
- self.trunk = nn.Sequential(*trunk_blocks)
+ for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+ mask_generate = quiz_machine.make_quiz_mask(
+ quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+ )
+        logits = diffuser.logits_hat_x_0_from_random_iteration(
+            model, x_0, mask_generate, prompt_noise=args.prompt_noise
+        )
+ loss_per_token = F.cross_entropy(
+ logits.transpose(1, 2), x_0, reduction="none"
+ )
+ if reduce:
+ loss += (loss_per_token * mask_generate).sum(dim=1)
+ else:
+ loss += loss_per_token * mask_generate
- self.readout = CacheWrapper(
- nn.Linear(in_features=dim_model, out_features=vocabulary_size)
- )
+ record.append(loss)
- with torch.no_grad():
- for m in self.modules():
- if isinstance(m, nn.Embedding):
- m.weight.normal_(mean=0, std=2e-2)
- elif isinstance(m, nn.LayerNorm):
- m.bias.zero_()
- m.weight.fill_(1.0)
+ loss = torch.cat(record, dim=0)
- def forward(self, bs):
- if torch.is_tensor(bs):
- return self.forward(BracketedSequence(bs)).x
- bs = self.embedding(bs)
- bs = self.positional_encoding(bs)
- bs = self.trunk(bs)
- bs = self.readout(bs)
- return bs
+ if log_probas:
+ return -loss
+ else:
+ return (-loss).exp()
######################################################################
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
- (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
-
def ae_batches(
quiz_machine,
return (loss_per_token * mask).mean()
-def NTC_channel_cat(*x):
- return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-def deterministic(mask_generate):
- return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
-
-
-######################################################################
-
-torch.set_printoptions(
- precision=None,
- threshold=None,
- edgeitems=None,
- linewidth=500,
- profile=None,
- sci_mode=None,
-)
-
-N = quiz_machine.problem.nb_colors
-T = args.nb_diffusion_iterations + 1
-diffusion_M = torch.empty(T, N, N)
-diffusion_M[0] = torch.eye(N)
-
-# i >0 j>0
-# P(X'=0 | X=0) = 1-epsilon
-# P(X'=i | X=0) = epsilon/(N-1)
-# P(X'=0 | X=i) = delta
-# P(X'=X | X=i) = 1-epsilon-delta
-# P(X'=j | X=i) = epsilon/(N-2)
-
-diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
-diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
-diffusion_M[1, 0, 1:] = args.diffusion_epsilon / (N - 1) + args.diffusion_delta
-diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 1)
-
-for k in range(1, N):
- diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
-
-# m = diffusion_M[1]
-
-# print(m)
-# print(m.sum(dim=0))
-# print(torch.linalg.matrix_power(m, 25))
-
-# exit(0)
-
-for t in range(2, T):
- # diffusion_M[t] = diffusion_M[1] @ diffusion_M[t - 1]
- diffusion_M[t] = torch.linalg.matrix_power(diffusion_M[1], t)
-
-# p = torch.full((N,), 1 / N)
-
-# for t in range(diffusion_M.size(0)):
-# print(diffusion_M[t] @ p)
-
-# print(diffusion_M[T-1])
-
-# exit(0)
-
-#
-# Given x_0 and t_0, t_1, ..., returns
-#
-# x_{t_0}, ..., x_{t_K} ~ P(X_{t_0}, ..., X_{t_K} | X_0=x_0)
-#
-
-
-def sample_x_t_given_x_0(x_0, t):
- noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
- r = torch.rand(x_0.size(), device=x_0.device)
- proba_erased = 1 - (1 - args.diffusion_delta) ** t
- mask_erased = (r <= proba_erased[:, None]).long()
- x_t = (1 - mask_erased) * x_0 + mask_erased * noise
-
- return x_t
-
-
-# This function returns a 2d tensor of same shape as low, full of
-# uniform random values in [0,1], such that, in every row, the values
-# corresponding to the True in low are all lesser than the values
-# corresponding to the False.
-
-
-def prioritized_rand(low):
- x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
- k = torch.rand(low.size(), device=low.device) + low.long()
- k = k.sort(dim=1).indices
- y = x.new(x.size())
- y.scatter_(dim=1, index=k, src=x)
- return y
-
-
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
- r = prioritized_rand(x_0 != x_t)
-
- mask_changes = (r <= args.diffusion_delta).long()
-
- x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
-
- return x_t_minus_1
-
-
-######################################################################
-# Non-uniform transitions, to be fixed?
-
-
-def ___sample_x_t_given_x_0(x_0, t):
- D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
- mask = (x_0 < quiz_machine.problem.nb_colors).long()
- probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
- dist = torch.distributions.categorical.Categorical(probs=probas)
- x_t = (1 - mask) * x_0 + mask * dist.sample()
- return x_t
-
-
-def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
- mask = (x_0 < quiz_machine.problem.nb_colors).long()
-
- # i = x_0[n,s], j = x_t[n,s]
- # probas[n,s,k] = M[1,x_t[n,s],k] M[t[n]-1,x_0[n,s],k] / M[t[n],x_0[n,s],x_t[n,s]]
-
- # A[n,s,k] = M[1,x_t[n,s],k]
- # B[n,s,k] = M[t[n]-1,x_0[n,s],k]
- # C[n,s,k] = M[t[n],x_0[n,s],x_t[n,s]]
- # probas = A * B / C
-
- N, S, K = x_0.size(0), x_0.size(1), diffusion_M.size(1)
-
- _1 = x_0.new_full((N, S, K), 1)
- _t = x_0.new_full((N, S, K), t)
- _k = torch.arange(K, device=x_0.device)[None, None, :].expand(N, S, K)
- _x_t = (mask * x_t)[:, :, None].expand(N, S, K)
- _x_0 = (mask * x_0)[:, :, None].expand(N, S, K)
-
- M = diffusion_M.to(x_0.device)
-
- probas = M[_1, _x_t, _k] * M[_t - 1, _x_0, _k] / M[_t, _x_0, _x_t]
-
- dist = torch.distributions.categorical.Categorical(probs=probas)
- x_t_minus_1 = (1 - mask) * x_0 + mask * dist.sample()
-
- return x_t_minus_1
-
-
-######################################################################
-
-# This function gets a clean target x_0, and a mask indicating which
-# part to generate (conditionnaly to the others), and returns the
-# logits starting from a x_t|X_0=x_0 picked at random with t random
-
-
-def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
- # We favor iterations near the clean signal
-
- probs_iterations = 0.1 ** torch.linspace(
- 0, 1, args.nb_diffusion_iterations, device=x_0.device
- )
-
- probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
- probs_iterations = probs_iterations.expand(x_0.size(0), -1)
- dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
- t = dist.sample() + 1
-
- x_t = sample_x_t_given_x_0(x_0, t)
-
- # Only the part to generate is degraded, the rest is a perfect
- # noise-free conditionning
-
- x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
- # We may inject noise to prevent high-complexity non-structure
- # signal to be generated as a way of "increasing reasoning
- # complexity"
-
- if prompt_noise > 0:
- mask_prompt_noise = (
- torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
- ).long()
- noise = torch.randint(
- quiz_machine.problem.nb_colors, x_t.size(), device=x_t.device
- )
- noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
- x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
-
- x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-
- with torch.amp.autocast("cuda"):
- logits_hat_x_0 = model(x_t_with_mask)
-
- return logits_hat_x_0
-
-
-######################################################################
-
-
-def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
- noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-
- single_iteration = deterministic(mask_generate)[:, None]
-
- if mask_hints is None:
- mask_start = mask_generate
- else:
- mask_start = mask_generate * (1 - mask_hints)
-
- x_t = (1 - mask_start) * x_0 + mask_start * noise
-
- changed = True
-
- for it in range(nb_iterations_max):
- x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
- with torch.amp.autocast("cuda"):
- logits = model(x_t_with_mask)
- logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
- dist = torch.distributions.categorical.Categorical(logits=logits)
-
- hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-
- hat_x_t_minus_1 = single_iteration * hat_x_0 + (
- 1 - single_iteration
- ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
-
- if hat_x_t_minus_1.equal(x_t):
- break
- else:
- changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
- x_t[changed] = hat_x_t_minus_1[changed]
-
- return x_t
-
-
-######################################################################
-
-
-def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
- record = []
-
- for x_0 in input.split(args.batch_size):
- loss = 0
-
- for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
- mask_generate = quiz_machine.make_quiz_mask(
- quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
- )
- logits = logits_hat_x_0_from_random_iteration(
- model, x_0, mask_generate, prompt_noise=args.prompt_noise
- )
- loss_per_token = F.cross_entropy(
- logits.transpose(1, 2), x_0, reduction="none"
- )
- if reduce:
- loss += (loss_per_token * mask_generate).sum(dim=1)
- else:
- loss += loss_per_token * mask_generate
-
- record.append(loss)
-
- loss = torch.cat(record, dim=0)
-
- if log_probas:
- return -loss
- else:
- return (-loss).exp()
-
-
######################################################################
c_quizzes=c_quizzes,
desc="test",
):
- logits = logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)
+ logits = diffuser.logits_hat_x_0_from_random_iteration(
+ model=model,
+ x_0=x_0,
+ mask_generate=mask_generate,
+ )
loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
acc_test_loss += loss.item() * x_0.size(0)
nb_test_samples += x_0.size(0)
c_quizzes=c_quizzes,
desc="test",
):
- result = ae_generate(model, (1 - mask_generate) * x_0, mask_generate)
+ result = diffuser.ae_generate(
+ model, (1 - mask_generate) * x_0, mask_generate
+ )
correct = (result == x_0).min(dim=1).values.long()
predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
:, :, 1
nb_train_samples, acc_train_loss = 0, 0.0
- scaler = torch.amp.GradScaler("cuda")
+ # scaler = torch.amp.GradScaler("cuda")
for x_0, mask_generate in ae_batches(
quiz_machine,
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- with torch.amp.autocast("cuda"):
- logits = logits_hat_x_0_from_random_iteration(
- model, x_0, mask_generate, prompt_noise=args.prompt_noise
- )
+ # with torch.amp.autocast("cuda"):
+ logits = diffuser.logits_hat_x_0_from_random_iteration(
+ model=model,
+ x_0=x_0,
+ mask_generate=mask_generate,
+ prompt_noise=args.prompt_noise,
+ )
loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
acc_train_loss += loss.item() * x_0.size(0)
nb_train_samples += x_0.size(0)
- scaler.scale(loss).backward()
+ loss.backward()
if nb_train_samples % args.batch_size == 0:
- scaler.step(model.optimizer)
+ model.optimizer.step()
+
+ # scaler.scale(loss).backward()
+
+ # if nb_train_samples % args.batch_size == 0:
+ # scaler.step(model.optimizer)
- scaler.update()
+ # scaler.update()
log_string(
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"