Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:28:16 +0000 (09:28 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:28:16 +0000 (09:28 +0200)
diffusion.py [deleted file]
main.py

diff --git a/diffusion.py b/diffusion.py
deleted file mode 100755 (executable)
index 629113a..0000000
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/usr/bin/env python
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-class Diffuser:
-    def __init__(self, mu_T_sampler, nb_iterations, proba_corruption):
-        self.mu_T_sampler = mu_T_sampler
-        self.nb_iterations = nb_iterations
-        self.proba_corruption = proba_corruption
-
-    def sample_x_t_given_x_0(self, x_0, t):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-        r = torch.rand(x_0.size(), device=x_0.device)
-        proba_erased = 1 - (1 - self.proba_corruption) ** t
-        mask_erased = (r <= proba_erased[:, None]).long()
-        x_t = (1 - mask_erased) * x_0 + mask_erased * noise
-
-        return x_t
-
-    # This function returns a 2d tensor of same shape as low, full of
-    # uniform random values in [0,1], such that, in every row, the values
-    # corresponding to the True in low are all lesser than the values
-    # corresponding to the False.
-
-    def prioritized_rand(self, low):
-        x = (
-            torch.rand(low.size(), device=low.device)
-            .sort(dim=1, descending=True)
-            .values
-        )
-        k = torch.rand(low.size(), device=low.device) + low.long()
-        k = k.sort(dim=1).indices
-        y = x.new(x.size())
-        y.scatter_(dim=1, index=k, src=x)
-        return y
-
-    def sample_x_t_minus_1_given_x_0_x_t(self, x_0, x_t):
-        r = self.prioritized_rand(x_0 != x_t)
-        mask_changes = (r <= self.proba_corruption).long()
-        x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
-        return x_t_minus_1
-
-    ######################################################################
-
-    def make_mask_hints(self, mask_generate, nb_hints):
-        if nb_hints is None:
-            mask_hints = torch.zeros(
-                mask_generate.size(),
-                device=mask_generate.device,
-                dtype=mask_generate.dtype,
-            )
-        else:
-            u = (
-                torch.rand(mask_generate.size(), device=mask_generate.device)
-                * mask_generate
-            )
-            v = u.sort(dim=1, descending=True).values.gather(
-                dim=1, index=nb_hints[:, None]
-            )
-            mask_hints = (u > v).long()
-
-        return mask_hints
-
-    # This function gets a clean target x_0, and a mask indicating which
-    # part to generate (conditionnaly to the others), and returns the
-    # logits starting from a x_t|X_0=x_0 picked at random with t random
-
-    def logits_hat_x_0_from_random_iteration(
-        self, model, x_0, mask_generate, nb_hints=None, prompt_noise=0.0
-    ):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-
-        single_iteration = (
-            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
-        ).long()[:, None]
-
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints) * single_iteration
-
-        # We favor iterations near the clean signal
-
-        probs_iterations = 0.1 ** torch.linspace(
-            0, 1, self.nb_iterations, device=x_0.device
-        )
-
-        probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-        probs_iterations = probs_iterations.expand(x_0.size(0), -1)
-        dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-        t = dist.sample() + 1
-
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
-        x_t = self.sample_x_t_given_x_0(x_0, t)
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
-        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-        # We may inject noise to prevent high-complexity non-structure
-        # signal to be generated as a way of "increasing reasoning
-        # complexity"
-
-        if prompt_noise > 0:
-            mask_prompt_noise = (
-                torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
-            ).long()
-            noise = self.mu_T_sampler(x_t.size(), device=x_t.device)
-            noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
-            x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
-
-        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits_hat_x_0 = model(x_t_with_mask)
-
-        return logits_hat_x_0
-
-    ######################################################################
-
-    def generate(self, model, x_0, mask_generate, nb_hints=None):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-
-        single_iteration = (
-            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
-        ).long()[:, None]
-
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
-
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * noise
-        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-        changed = True
-
-        for it in range(self.nb_iterations):
-            x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-                logits = model(x_t_with_mask)
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-
-            hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-
-            hat_x_t_minus_1 = single_iteration * hat_x_0 + (
-                1 - single_iteration
-            ) * self.sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
-
-            if hat_x_t_minus_1.equal(x_t):
-                break
-            else:
-                changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
-                x_t[changed] = hat_x_t_minus_1[changed]
-
-        return x_t
diff --git a/main.py b/main.py
index 380be1e..772ef9f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -3,9 +3,6 @@
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
-# > A > f(A) > B ; > f(B)
-# < f(B) ; < B < f(A) < A
-
 # Written by Francois Fleuret <francois@fleuret.org>
 
 import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
@@ -29,8 +26,6 @@ import threading, subprocess
 
 # torch.set_default_dtype(torch.bfloat16)
 
-import diffusion
-
 ######################################################################
 
 parser = argparse.ArgumentParser(