+++ /dev/null
-#!/usr/bin/env python
-
-import torch
-
-from torch.nn import functional as F
-
-
-# Cat (N, T) tensors along a new channel dimension, returning a single
-# (N, T, C) tensor with one channel per input
-
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
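-
-
-# Illustrative shape check (a sketch added for exposition; the _demo_*
-# functions below are not part of the original pipeline)
-
-def _demo_ntc_channel_cat():
-    x = torch.randint(10, (4, 16))  # (N, T) token indices
-    m = torch.randint(2, (4, 16))  # (N, T) binary mask
-    y = NTC_channel_cat(x, m)
-    assert y.size() == (4, 16, 2)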
-
-
-class Diffuser:
-    def __init__(self, mu_T_sampler, nb_iterations, proba_corruption):
-        self.mu_T_sampler = mu_T_sampler
-        self.nb_iterations = nb_iterations
-        self.proba_corruption = proba_corruption
-
-    # Samples x_t given x_0: every symbol is replaced i.i.d. by noise
-    # with probability 1 - (1 - proba_corruption) ** t, the probability
-    # of having been corrupted at least once in t steps
-
-    def sample_x_t_given_x_0(self, x_0, t):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-        r = torch.rand(x_0.size(), device=x_0.device)
-        proba_erased = 1 - (1 - self.proba_corruption) ** t
-        mask_erased = (r <= proba_erased[:, None]).long()
-        x_t = (1 - mask_erased) * x_0 + mask_erased * noise
-
-        return x_t
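-
-    # Illustrative sanity check (a sketch added for exposition).
-    # Assuming a uniform mu_T_sampler over K symbols, the measured
-    # fraction should be close to
-    # (1 - (1 - proba_corruption) ** t) * (1 - 1 / K), since the noise
-    # can hit the original symbol
-
-    def _demo_corruption_rate(self, t=10, nb_positions=100000):
-        x_0 = torch.zeros(1, nb_positions, dtype=torch.int64)
-        x_t = self.sample_x_t_given_x_0(x_0, torch.tensor([t]))
-        return (x_t != x_0).float().mean().item()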
-
-    # This function returns a 2d tensor of the same shape as low,
-    # filled with uniform random values in [0, 1], such that, in every
-    # row, the values at the positions where low is True are all
-    # smaller than the values at the positions where low is False.
-
-    def prioritized_rand(self, low):
-        # random values sorted in decreasing order in every row
-        x = (
-            torch.rand(low.size(), device=low.device)
-            .sort(dim=1, descending=True)
-            .values
-        )
-        # a permutation such that, in every row, the True positions of
-        # low come last, hence receive the smallest values of x
-        k = torch.rand(low.size(), device=low.device) + low.long()
-        k = k.sort(dim=1).indices
-        y = torch.empty_like(x)
-        y.scatter_(dim=1, index=k, src=x)
-        return y
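-
-    # Illustrative property check (a sketch added for exposition): per
-    # row, every value at a True position of low is below every value
-    # at a False position
-
-    def _demo_prioritized_rand(self):
-        low = torch.rand(4, 16) < 0.5
-        y = self.prioritized_rand(low)
-        max_low = y.masked_fill(~low, 0.0).amax(dim=1)
-        min_high = y.masked_fill(low, 1.0).amin(dim=1)
-        assert (max_low <= min_high).all()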
-
-    # One reverse step: reverts roughly a fraction proba_corruption of
-    # the positions to x_0, targeting in priority the positions where
-    # x_t differs from x_0
-
-    def sample_x_t_minus_1_given_x_0_x_t(self, x_0, x_t):
-        r = self.prioritized_rand(x_0 != x_t)
-        mask_changes = (r <= self.proba_corruption).long()
-        x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
-        return x_t_minus_1
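-
-    # Illustrative check (a sketch added for exposition): a reverse
-    # step can only move x_t toward x_0, it never corrupts a position
-    # that already matches
-
-    def _demo_reverse_step(self, x_0, x_t):
-        x_t_minus_1 = self.sample_x_t_minus_1_given_x_0_x_t(x_0, x_t)
-        assert ((x_t_minus_1 != x_0) <= (x_t != x_0)).all()
-        return x_t_minus_1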
-
-    ######################################################################
-
-    # Returns a mask with, in every row, nb_hints positions picked at
-    # random among those set in mask_generate (all zeros if nb_hints is
-    # None)
-
-    def make_mask_hints(self, mask_generate, nb_hints):
-        if nb_hints is None:
-            mask_hints = torch.zeros(
-                mask_generate.size(),
-                device=mask_generate.device,
-                dtype=mask_generate.dtype,
-            )
-        else:
-            u = (
-                torch.rand(mask_generate.size(), device=mask_generate.device)
-                * mask_generate
-            )
-            v = u.sort(dim=1, descending=True).values.gather(
-                dim=1, index=nb_hints[:, None]
-            )
-            mask_hints = (u > v).long()
-
-        return mask_hints
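-
-    # Illustrative check (a sketch added for exposition): exactly
-    # nb_hints positions are revealed, all inside mask_generate
-
-    def _demo_make_mask_hints(self):
-        mask_generate = torch.ones(4, 16, dtype=torch.int64)
-        nb_hints = torch.tensor([0, 1, 2, 3])
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
-        assert (mask_hints.sum(dim=1) == nb_hints).all()
-        assert (mask_hints <= mask_generate).all()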
-
-    # This function takes a clean target x_0 and a mask indicating
-    # which part to generate (conditionally on the rest), and returns
-    # the model's logits for hat_x_0, computed from an x_t sampled from
-    # X_t | X_0 = x_0 with t picked at random
-
-    def logits_hat_x_0_from_random_iteration(
-        self, model, x_0, mask_generate, nb_hints=None, prompt_noise=0.0
-    ):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-
-        # When less than half of the sequence has to be generated, we
-        # use a single denoising pass instead of the iterative process
-        single_iteration = (
-            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
-        ).long()[:, None]
-
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints) * single_iteration
-
-        # We favor sampling t near the clean signal: the probability of
-        # picking t decays geometrically, by a factor 10 from t=1 to
-        # t=nb_iterations
-
-        probs_iterations = 0.1 ** torch.linspace(
-            0, 1, self.nb_iterations, device=x_0.device
-        )
-
-        probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-        probs_iterations = probs_iterations.expand(x_0.size(0), -1)
-        dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-        t = dist.sample() + 1
-
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
-        x_t = self.sample_x_t_given_x_0(x_0, t)
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
-        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-        # We may inject noise in the prompt to prevent the generation
-        # of a high-complexity non-structured signal, as a way of
-        # "increasing reasoning complexity"
-
-        if prompt_noise > 0:
-            mask_prompt_noise = (
-                torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
-            ).long()
-            noise = self.mu_T_sampler(x_t.size(), device=x_t.device)
-            noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
-            x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
-
-        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits_hat_x_0 = model(x_t_with_mask)
-
-        return logits_hat_x_0
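-
-    # Illustrative training-step sketch (added for exposition; "model"
-    # and "optimizer" are placeholders, and for simplicity the loss is
-    # taken over all positions rather than only over mask_generate)
-
-    def _demo_training_step(self, model, optimizer, x_0, mask_generate):
-        logits = self.logits_hat_x_0_from_random_iteration(
-            model, x_0, mask_generate
-        )
-        loss = F.cross_entropy(logits.transpose(1, 2), x_0)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        return loss.item()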
-
-    ######################################################################
-
-    # Generates the positions of x_0 selected by mask_generate, either
-    # in a single denoising pass when the part to generate is small, or
-    # by running nb_iterations reverse steps with early stopping
-
-    def generate(self, model, x_0, mask_generate, nb_hints=None):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-
-        single_iteration = (
-            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
-        ).long()[:, None]
-
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
-
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * noise
-        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-        # Once a sequence stops changing, it is frozen for the
-        # remaining iterations
-        changed = True
-
-        for it in range(self.nb_iterations):
-            x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-                logits = model(x_t_with_mask)
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-
-            hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-
-            hat_x_t_minus_1 = single_iteration * hat_x_0 + (
-                1 - single_iteration
-            ) * self.sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
-
-            if hat_x_t_minus_1.equal(x_t):
-                break
-            else:
-                changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
-                x_t[changed] = hat_x_t_minus_1[changed]
-
-        return x_t
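-
-
-# Illustrative end-to-end sketch (added for exposition, with a toy
-# model emitting uniform logits; all names and sizes are placeholders)
-
-def _demo_generate():
-    nb_tokens = 10
-
-    def uniform_model(x):
-        return torch.zeros(x.size(0), x.size(1), nb_tokens)
-
-    diffuser = Diffuser(
-        mu_T_sampler=lambda size, device: torch.randint(
-            nb_tokens, size, device=device
-        ),
-        nb_iterations=25,
-        proba_corruption=0.05,
-    )
-
-    x_0 = torch.randint(nb_tokens, (4, 16))
-    mask_generate = torch.zeros(4, 16, dtype=torch.int64)
-    mask_generate[:, 8:] = 1  # regenerate the second half of each row
-    x = diffuser.generate(uniform_model, x_0, mask_generate)
-    assert (x[:, :8] == x_0[:, :8]).all()  # the prompt is untouched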