X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=e7be8c1c8651a3cc4ee4a578586e2e4dd3c29bcf;hb=8d71d2f43cec159c7ca368c1dc4fa76f061d13b7;hp=7327522e1f13b08965cff1ddb91cad75c68e22fd;hpb=e8500a3f0cec4be59442e2b3bdbe692b04a9524a;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 7327522..e7be8c1 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -289,6 +289,9 @@ model.eval() if train_input.dim() == 2 and train_input.size(1) == 1: fig = plt.figure() + fig.set_figheight(5) + fig.set_figwidth(8) + ax = fig.add_subplot(1, 1, 1) x = generate((10000, 1), T, alpha, alpha_bar, sigma, @@ -310,12 +313,12 @@ if train_input.dim() == 2 and train_input.size(1) == 1: ax.legend(frameon = False, loc = 2) - filename = f'diffusion_{args.data}.pdf' + filename = f'minidiffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'): - plt.get_current_fig_manager().window.setGeometry(2, 2, 2048, 768) + plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768) plt.show() ######################################## @@ -323,6 +326,9 @@ if train_input.dim() == 2 and train_input.size(1) == 1: elif train_input.dim() == 2 and train_input.size(1) == 2: fig = plt.figure() + fig.set_figheight(6) + fig.set_figwidth(6) + ax = fig.add_subplot(1, 1, 1) x = generate((1000, 2), T, alpha, alpha_bar, sigma, @@ -344,7 +350,7 @@ elif train_input.dim() == 2 and train_input.size(1) == 2: ax.legend(frameon = False, loc = 2) - filename = f'diffusion_{args.data}.pdf' + filename = f'minidiffusion_{args.data}.pdf' print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') @@ -369,7 +375,7 @@ elif train_input.dim() == 4: result = 1 - torch.cat((t, x), 2) / 255 - filename = f'diffusion_{args.data}.png' + filename = f'minidiffusion_{args.data}.png' print(f'saving {filename}') torchvision.utils.save_image(result, filename)