class VignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
seeds = torch.LongTensor(self.nb_batches).random_()
######################################################################
class CompressedVignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
self.targets = []
self.input_storages = []