formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
-parser.add_argument('--nb_train_batches',
- type = int, default = 1000,
+parser.add_argument('--nb_train_samples',
+ type = int, default = 100000,
help = 'How many samples for train')
-parser.add_argument('--nb_test_batches',
- type = int, default = 100,
+parser.add_argument('--nb_test_samples',
+ type = int, default = 10000,
help = 'How many samples for test')
parser.add_argument('--nb_epochs',
######################################################################
+def int_to_suffix(n):
+ if n > 1000000 and n%1000000 == 0:
+ return str(n//1000000) + 'M'
+ elif n > 1000 and n%1000 == 0:
+ return str(n//1000) + 'K'
+ else:
+ return str(n)
+
+######################################################################
+
+if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
+ print('The number of samples must be a multiple of the batch size.')
+ raise
+
for problem_number in range(1, 24):
log_string('**** problem ' + str(problem_number) + ' ****')
model_filename = model.name + '_' + \
str(problem_number) + '_' + \
- str(args.nb_train_batches) + '.param'
+ int_to_suffix(args.nb_train_samples) + '.param'
nb_parameters = 0
for p in model.parameters(): nb_parameters += p.numel()
if args.compress_vignettes:
train_set = CompressedVignetteSet(problem_number,
- args.nb_train_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
else:
train_set = VignetteSet(problem_number,
- args.nb_train_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
- log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
+ log_string('data_generation {:0.2f} samples / s'.format(
+ train_set.nb_samples / (time.time() - t))
+ )
train_model(model, train_set)
torch.save(model.state_dict(), model_filename)
if args.compress_vignettes:
test_set = CompressedVignetteSet(problem_number,
- args.nb_test_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
else:
test_set = VignetteSet(problem_number,
- args.nb_test_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
- log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
+ log_string('data_generation {:0.2f} samples / s'.format(
+ test_set.nb_samples / (time.time() - t))
+ )
nb_test_errors = nb_errors(model, test_set)
class VignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
seeds = torch.LongTensor(self.nb_batches).random_()
######################################################################
class CompressedVignetteSet:
- def __init__(self, problem_number, nb_batches, batch_size, cuda = False):
+ def __init__(self, problem_number, nb_samples, batch_size, cuda = False):
+
+ if nb_samples%batch_size > 0:
+ print('nb_samples must be a mutiple of batch_size')
+ raise
+
self.cuda = cuda
self.batch_size = batch_size
self.problem_number = problem_number
- self.nb_batches = nb_batches
+ self.nb_batches = nb_samples // batch_size
self.nb_samples = self.nb_batches * self.batch_size
self.targets = []
self.input_storages = []