X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=066cbbbe1fa365458ea163bea0e64e1cd2787c74;hb=87d376ff7929347865b199d05d003ab3b168f249;hp=c88765ce2cbe0a92647081a1a60a3b4954a33e56;hpb=b19c2b7ddf3e4db73d422c3c7e6c4371f9d6e657;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index c88765c..066cbbb 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -66,12 +66,14 @@ def sample_mnist(nb): return result samplers = { - 'gaussian_mixture': sample_gaussian_mixture, - 'ramp': sample_ramp, - 'two_discs': sample_two_discs, - 'disc_grid': sample_disc_grid, - 'spiral': sample_spiral, - 'mnist': sample_mnist, + f.__name__.removeprefix('sample_') : f for f in [ + sample_gaussian_mixture, + sample_ramp, + sample_two_discs, + sample_disc_grid, + sample_spiral, + sample_mnist, + ] } ###################################################################### @@ -313,7 +315,7 @@ if train_input.dim() == 2 and train_input.size(1) == 1: ax.legend(frameon = False, loc = 2) - filename = f'diffusion_{args.data}.pdf' + filename = f'minidiffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') @@ -350,7 +352,7 @@ elif train_input.dim() == 2 and train_input.size(1) == 2: ax.legend(frameon = False, loc = 2) - filename = f'diffusion_{args.data}.pdf' + filename = f'minidiffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') @@ -375,7 +377,7 @@ elif train_input.dim() == 4: result = 1 - torch.cat((t, x), 2) / 255 - filename = f'diffusion_{args.data}.png' + filename = f'minidiffusion_{args.data}.png' print(f'saving {filename}') torchvision.utils.save_image(result, filename)