X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=minidiffusion.py;h=42dff7ca463a82b5b75455445c99145018f43820;hb=9005785d9bc97380f256d7845cdbd4a7de4b3371;hp=6fd85643e1b83d6555c251e55e19146231e4e2a4;hpb=f419d073353e4b468eeb0a4fc82dee37e447b341;p=pytorch.git diff --git a/minidiffusion.py b/minidiffusion.py index 6fd8564..42dff7c 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -272,6 +272,8 @@ if train_input.dim() == 2: x = generate((10000, 1), model) ax.set_xlim(-1.25, 1.25) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) d = train_input.flatten().detach().to('cpu').numpy() ax.hist(d, 25, (-1, 1), @@ -297,11 +299,11 @@ if train_input.dim() == 2: d = x.detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], - facecolors = 'none', color = 'red', label = 'Synthesis') + s = 2.0, color = 'red', label = 'Synthesis') d = train_input[:x.size(0)].detach().to('cpu').numpy() ax.scatter(d[:, 0], d[:, 1], - s = 1.0, color = 'blue', label = 'Train') + s = 2.0, color = 'gray', label = 'Train') ax.legend(frameon = False, loc = 2)