X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=svrtset.py;h=8bfdde199bd1c858d30ee42fe3e0ebbe30eeb860;hb=faac6461d3482204898ad97f811f58889cc37e1d;hp=49f32799da7b0ba4a880d9aacc90278617908220;hpb=e754d1075d8d0a5949e71f426ab07ce73be6099e;p=pysvrt.git diff --git a/svrtset.py b/svrtset.py index 49f3279..8bfdde1 100755 --- a/svrtset.py +++ b/svrtset.py @@ -48,10 +48,11 @@ class VignetteSet: raise self.cuda = cuda - self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_samples // batch_size - self.nb_samples = self.nb_batches * self.batch_size + + self.batch_size = batch_size + self.nb_samples = nb_samples + self.nb_batches = self.nb_samples // self.batch_size seeds = torch.LongTensor(self.nb_batches).random_() mp_args = [] @@ -96,10 +97,12 @@ class CompressedVignetteSet: raise self.cuda = cuda - self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_samples // batch_size - self.nb_samples = self.nb_batches * self.batch_size + + self.batch_size = batch_size + self.nb_samples = nb_samples + self.nb_batches = self.nb_samples // self.batch_size + self.targets = [] self.input_storages = []