Update.
authorFrancois Fleuret <francois@fleuret.org>
Wed, 27 Jul 2022 09:18:06 +0000 (11:18 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Wed, 27 Jul 2022 09:18:06 +0000 (11:18 +0200)
main.py
result_picoclvr_0007.png [deleted file]
result_picoclvr_0009.png [new file with mode: 0644]

diff --git a/main.py b/main.py
index 6592204..b2adf98 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -111,7 +111,7 @@ for n in vars(args):
 ######################################################################
 
 def autoregression(
-        model,
+        model, batch_size,
         nb_samples, nb_tokens_to_generate, starting_input = None,
         device = torch.device('cpu')
 ):
@@ -126,7 +126,7 @@ def autoregression(
         first = starting_input.size(1)
         results = torch.cat((starting_input, results), 1)
 
-    for input in results.split(args.batch_size):
+    for input in results.split(batch_size):
         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
             output = model(input)
             logits = output[:, s]
@@ -386,7 +386,7 @@ class TaskMNIST(Task):
         return 256
 
     def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = autoregression(model, nb_samples, 28 * 28, device = self.device)
+        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)
diff --git a/result_picoclvr_0007.png b/result_picoclvr_0007.png
deleted file mode 100644 (file)
index 7baee57..0000000
Binary files a/result_picoclvr_0007.png and /dev/null differ
diff --git a/result_picoclvr_0009.png b/result_picoclvr_0009.png
new file mode 100644 (file)
index 0000000..18dad27
Binary files /dev/null and b/result_picoclvr_0009.png differ