X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=066cbbbe1fa365458ea163bea0e64e1cd2787c74;hb=88302d2fbe5ca4adc72f24a78e8d2abb5326418c;hp=7327522e1f13b08965cff1ddb91cad75c68e22fd;hpb=e8500a3f0cec4be59442e2b3bdbe692b04a9524a;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 7327522..066cbbb 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -66,12 +66,14 @@ def sample_mnist(nb): return result samplers = { - 'gaussian_mixture': sample_gaussian_mixture, - 'ramp': sample_ramp, - 'two_discs': sample_two_discs, - 'disc_grid': sample_disc_grid, - 'spiral': sample_spiral, - 'mnist': sample_mnist, + f.__name__.removeprefix('sample_') : f for f in [ + sample_gaussian_mixture, + sample_ramp, + sample_two_discs, + sample_disc_grid, + sample_spiral, + sample_mnist, + ] } ###################################################################### @@ -289,6 +291,9 @@ model.eval() if train_input.dim() == 2 and train_input.size(1) == 1: fig = plt.figure() + fig.set_figheight(5) + fig.set_figwidth(8) + ax = fig.add_subplot(1, 1, 1) x = generate((10000, 1), T, alpha, alpha_bar, sigma, @@ -310,12 +315,12 @@ if train_input.dim() == 2 and train_input.size(1) == 1: ax.legend(frameon = False, loc = 2) - filename = f'diffusion_{args.data}.pdf' + filename = f'minidiffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'): - plt.get_current_fig_manager().window.setGeometry(2, 2, 2048, 768) + plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) plt.show() ######################################## @@ -323,6 +328,9 @@ if train_input.dim() == 2 and train_input.size(1) == 1: elif train_input.dim() == 2 and train_input.size(1) == 2: fig = plt.figure() + fig.set_figheight(6) + fig.set_figwidth(6) + ax = fig.add_subplot(1, 1, 1) x = generate((1000, 2), T, alpha, alpha_bar, sigma, @@ -344,7 +352,7 @@ elif train_input.dim() == 2 and train_input.size(1) == 2: ax.legend(frameon = False, loc = 2) - filename = f'diffusion_{args.data}.pdf' + filename = f'minidiffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') @@ -369,7 +377,7 @@ elif train_input.dim() == 4: result = 1 - torch.cat((t, x), 2) / 255 - filename = f'diffusion_{args.data}.png' + filename = f'minidiffusion_{args.data}.png' print(f'saving {filename}') torchvision.utils.save_image(result, filename)