From 8ea0e3c5cc303718a8b508b656f7aa9e64ea3070 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 27 Jul 2022 11:18:06 +0200 Subject: [PATCH] Update. --- main.py | 6 +++--- result_picoclvr_0007.png | Bin 623 -> 0 bytes result_picoclvr_0009.png | Bin 0 -> 872 bytes 3 files changed, 3 insertions(+), 3 deletions(-) delete mode 100644 result_picoclvr_0007.png create mode 100644 result_picoclvr_0009.png diff --git a/main.py b/main.py index 6592204..b2adf98 100755 --- a/main.py +++ b/main.py @@ -111,7 +111,7 @@ for n in vars(args): ###################################################################### def autoregression( - model, + model, batch_size, nb_samples, nb_tokens_to_generate, starting_input = None, device = torch.device('cpu') ): @@ -126,7 +126,7 @@ def autoregression( first = starting_input.size(1) results = torch.cat((starting_input, results), 1) - for input in results.split(args.batch_size): + for input in results.split(batch_size): for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): output = model(input) logits = output[:, s] @@ -386,7 +386,7 @@ class TaskMNIST(Task): return 256 def produce_results(self, n_epoch, model, nb_samples = 64): - results = autoregression(model, nb_samples, 28 * 28, device = self.device) + results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8) diff --git a/result_picoclvr_0007.png b/result_picoclvr_0007.png deleted file mode 100644 index 7baee570f0d73ac7bd89c0fbc4accc3dd2cafabc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 623 zcmV-#0+9WQP)^k{x&ByR}Zh&)Xb-H*OL9l&Skd;aOVPIH)< z(los=mxzd7FT`c0ls0bPQ%XckluX@4W%iFh?#c!(FS-d=CyR7=b%>ixxRd+VMnGi!bHT( zwebUc+l_24QL6yTSVB_BM_^`@1`(yTScKn}M_P)-bv*huB>i4;lAm7BKO=r{`$lBr zHW69K{9m{g;8Ts;!E?AVElxw;k49)8o?k+yY%^LDj(Lg9LIK$y?P5TDTJ+xb+Xin< zGRqW*;a1i5i2i6-MY%x1<{@yT*{QLI%%}-HU^;V7V7jTIXJOpl3(B-vHpVQ+XO86{ zP+bfH5dod=VZxOy3;DtAJ5k2%zVBd}wMAEnqY>TfqROkiyS18t`syVJlh?hpM2j4q zia`RB#p&|Kg!;M;cTHQ1REZUj3f^7!@V>(Y+e4&vnPu3H!fP#pB?-$(Tb*YazvHrO^nq=9Eblkr002ov JPDHLkV1lCSGkX94 diff --git a/result_picoclvr_0009.png b/result_picoclvr_0009.png new file mode 100644 index 0000000000000000000000000000000000000000..18dad2782bd2990f67065a4b48c3cf6aa1fbe6d5 GIT binary patch literal 872 zcmV-u1DE`XP)F$Iq`8l0%CuEk$Kct;y893g~|{-qE?q6!+&4ElN0kG`2<>KcttR2b8G3G4;R zwT%=7GA?OmsE!fx^BB04P3kcCmKB3u5};%yiJng+G$`kP`vVJvtvRs89S#n+)hI+JlP;6 zDU;_m!MZ4!5W}Z6YvG9{zf5QY&6~d7HAFF(eNIVkQ}4(?J{%8!xTT+%O*5T(MuB_{ zZg|=>-kmy)nEwkLS^nA<7iGoX0hkiQXPcvh$vXfzDg1W|v-qWxZG&&h$*Hi zFD}ipnxLo^!}~C5y_7&sQ^}e<0ygswo*3G^4&^B*5z_GUB z+bnER<5nFV54X&6q8h)tMvaiS*ekI4K6 yX-js^Vy%E4$LU{qBoRZ##5({};%-}#FTsC_>{b4z$`0HB0000