Update.
[pytorch.git] / minidiffusion.py
index 7327522..066cbbb 100755 (executable)
@@ -66,12 +66,14 @@ def sample_mnist(nb):
     return result
 
 samplers = {
-    'gaussian_mixture': sample_gaussian_mixture,
-    'ramp': sample_ramp,
-    'two_discs': sample_two_discs,
-    'disc_grid': sample_disc_grid,
-    'spiral': sample_spiral,
-    'mnist': sample_mnist,
+    f.__name__.removeprefix('sample_') : f for f in [
+        sample_gaussian_mixture,
+        sample_ramp,
+        sample_two_discs,
+        sample_disc_grid,
+        sample_spiral,
+        sample_mnist,
+    ]
 }
 
 ######################################################################
@@ -289,6 +291,9 @@ model.eval()
 if train_input.dim() == 2 and train_input.size(1) == 1:
 
     fig = plt.figure()
+    fig.set_figheight(5)
+    fig.set_figwidth(8)
+
     ax = fig.add_subplot(1, 1, 1)
 
     x = generate((10000, 1), T, alpha, alpha_bar, sigma,
@@ -310,12 +315,12 @@ if train_input.dim() == 2 and train_input.size(1) == 1:
 
     ax.legend(frameon = False, loc = 2)
 
-    filename = f'diffusion_{args.data}.pdf'
+    filename = f'minidiffusion_{args.data}.pdf'
     print(f'saving {filename}')
     fig.savefig(filename, bbox_inches='tight')
 
     if not args.no_window and hasattr(plt.get_current_fig_manager(), 'window'):
-        plt.get_current_fig_manager().window.setGeometry(2, 2, 2048, 768)
+        plt.get_current_fig_manager().window.setGeometry(2, 2, 1024, 768)
         plt.show()
 
 ########################################
@@ -323,6 +328,9 @@ if train_input.dim() == 2 and train_input.size(1) == 1:
 elif train_input.dim() == 2 and train_input.size(1) == 2:
 
     fig = plt.figure()
+    fig.set_figheight(6)
+    fig.set_figwidth(6)
+
     ax = fig.add_subplot(1, 1, 1)
 
     x = generate((1000, 2), T, alpha, alpha_bar, sigma,
@@ -344,7 +352,7 @@ elif train_input.dim() == 2 and train_input.size(1) == 2:
 
     ax.legend(frameon = False, loc = 2)
 
-    filename = f'diffusion_{args.data}.pdf'
+    filename = f'minidiffusion_{args.data}.pdf'
     print(f'saving {filename}')
     fig.savefig(filename, bbox_inches='tight')
 
@@ -369,7 +377,7 @@ elif train_input.dim() == 4:
 
     result = 1 - torch.cat((t, x), 2) / 255
 
-    filename = f'diffusion_{args.data}.png'
+    filename = f'minidiffusion_{args.data}.png'
     print(f'saving {filename}')
     torchvision.utils.save_image(result, filename)