formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
-parser.add_argument('--nb_train_batches',
- type = int, default = 1000,
+parser.add_argument('--nb_train_samples',
+ type = int, default = 100000,
help = 'How many samples for train')
-parser.add_argument('--nb_test_batches',
- type = int, default = 100,
+parser.add_argument('--nb_test_samples',
+ type = int, default = 10000,
help = 'How many samples for test')
parser.add_argument('--nb_epochs',
######################################################################
+def int_to_suffix(n):
+ if n > 1000000 and n%1000000 == 0:
+ return str(n//1000000) + 'M'
+ elif n > 1000 and n%1000 == 0:
+ return str(n//1000) + 'K'
+ else:
+ return str(n)
+
+######################################################################
+
+if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0:
+ print('The number of samples must be a multiple of the batch size.')
+ raise
+
for problem_number in range(1, 24):
log_string('**** problem ' + str(problem_number) + ' ****')
model_filename = model.name + '_' + \
str(problem_number) + '_' + \
- str(args.nb_train_batches) + '.param'
+ int_to_suffix(args.nb_train_samples) + '.param'
nb_parameters = 0
for p in model.parameters(): nb_parameters += p.numel()
if args.compress_vignettes:
train_set = CompressedVignetteSet(problem_number,
- args.nb_train_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
else:
train_set = VignetteSet(problem_number,
- args.nb_train_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
- log_string('data_generation {:0.2f} samples / s'.format(train_set.nb_samples / (time.time() - t)))
+ log_string('data_generation {:0.2f} samples / s'.format(
+ train_set.nb_samples / (time.time() - t))
+ )
train_model(model, train_set)
torch.save(model.state_dict(), model_filename)
if args.compress_vignettes:
test_set = CompressedVignetteSet(problem_number,
- args.nb_test_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
else:
test_set = VignetteSet(problem_number,
- args.nb_test_batches, args.batch_size,
- cuda=torch.cuda.is_available())
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
- log_string('data_generation {:0.2f} samples / s'.format(test_set.nb_samples / (time.time() - t)))
+ log_string('data_generation {:0.2f} samples / s'.format(
+ test_set.nb_samples / (time.time() - t))
+ )
nb_test_errors = nb_errors(model, test_set)