- def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):
- self.nb_batches = nb_samples // batch_size
- self.nb_samples = self.nb_batches * self.batch_size
+
+ self.batch_size = batch_size
+ self.nb_samples = nb_samples
+ self.nb_batches = self.nb_samples // self.batch_size
self.data = []
for b in range(0, self.nb_batches):
self.data.append(generate_one_batch(mp_args[b]))
self.data = []
for b in range(0, self.nb_batches):
self.data.append(generate_one_batch(mp_args[b]))
- def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False, logger = None):
- self.nb_batches = nb_samples // batch_size
- self.nb_samples = self.nb_batches * self.batch_size
+
+ self.batch_size = batch_size
+ self.nb_samples = nb_samples
+ self.nb_batches = self.nb_samples // self.batch_size
+
for b in range(0, self.nb_batches):
target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)
for b in range(0, self.nb_batches):
target = torch.LongTensor(self.batch_size).bernoulli_(0.5)
input = svrt.generate_vignettes(problem_number, target)