X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=vignette_set.py;h=aef23d875fd285f50f66ca1d14f130fbc30a16d0;hb=91b12f8980a69a99fd6bbdc9b6f6a422dd8cd15a;hp=0b6de7e17d7db8fc4897f844c889acf549cda562;hpb=c80fb2d538e0ccacb2523b762888db5ddada2a6e;p=pysvrt.git diff --git a/vignette_set.py b/vignette_set.py index 0b6de7e..aef23d8 100755 --- a/vignette_set.py +++ b/vignette_set.py @@ -18,7 +18,7 @@ # General Public License for more details. # # You should have received a copy of the GNU General Public License -# along with selector. If not, see . +# along with svrt. If not, see . import torch from math import sqrt @@ -41,11 +41,16 @@ def generate_one_batch(s): class VignetteSet: - def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + + if nb_samples%batch_size > 0: + print('nb_samples must be a mutiple of batch_size') + raise + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_batches + self.nb_batches = nb_samples // batch_size self.nb_samples = self.nb_batches * self.batch_size seeds = torch.LongTensor(self.nb_batches).random_() @@ -83,11 +88,16 @@ class VignetteSet: ###################################################################### class CompressedVignetteSet: - def __init__(self, problem_number, nb_batches, batch_size, cuda = False): + def __init__(self, problem_number, nb_samples, batch_size, cuda = False): + + if nb_samples%batch_size > 0: + print('nb_samples must be a mutiple of batch_size') + raise + self.cuda = cuda self.batch_size = batch_size self.problem_number = problem_number - self.nb_batches = nb_batches + self.nb_batches = nb_samples // batch_size self.nb_samples = self.nb_batches * self.batch_size self.targets = [] self.input_storages = []