formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
-parser.add_argument('--nb_train_samples',
-                    type = int, default = 100000,
-                    help = 'How many samples for train')
+parser.add_argument('--nb_train_batches',
+                    type = int, default = 1000,
+                    help = 'How many batches for train')
-parser.add_argument('--nb_test_samples',
-                    type = int, default = 10000,
-                    help = 'How many samples for test')
+parser.add_argument('--nb_test_batches',
+                    type = int, default = 100,
+                    help = 'How many batches for test')
parser.add_argument('--nb_epochs',
type = int, default = 50,
help = 'How many training epochs')
+parser.add_argument('--batch_size',
+ type = int, default = 100,
+ help = 'Mini-batch size')
+
parser.add_argument('--log_file',
type = str, default = 'cnn-svrt.log',
help = 'Log file name')
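With these defaults the derived totals match the old sample counts:
1000 batches of 100 give the previous 100000 training samples, and 100
batches of 100 give the previous 10000 test samples. A minimal sketch of
how the new arguments combine (the parse_args() call and the derived
variable names here are illustrative, not part of the patch):

    args = parser.parse_args()

    # Sample counts are now derived from batch counts:
    # 1000 * 100 = 100000 train, 100 * 100 = 10000 test.
    nb_train_samples = args.nb_train_batches * args.batch_size
    nb_test_samples = args.nb_test_batches * args.batch_size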
return x
def train_model(model, train_input, train_target):
+ bs = args.batch_size
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
criterion.cuda()
- optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
+ optimizer = optim.SGD(model.parameters(), lr = 1e-2)
for k in range(0, args.nb_epochs):
acc_loss = 0.0
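The body of the epoch loop lies outside this hunk. A sketch of how it
would consume bs, assuming the same narrow()-based batching used in
nb_errors below (loss.item() assumes a reasonably recent PyTorch; the
original may accumulate the loss differently):

    for b in range(0, train_input.size(0), bs):
        # Forward pass on one mini-batch of bs samples.
        output = model(train_input.narrow(0, b, bs))
        loss = criterion(output, train_target.narrow(0, b, bs))
        acc_loss = acc_loss + loss.item()
        # Backward pass and parameter update.
        model.zero_grad()
        loss.backward()
        optimizer.step()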
######################################################################
-def nb_errors(model, data_input, data_target, bs = 100):
- ne = 0
+def nb_errors(model, data_input, data_target):
+ bs = args.batch_size
+ ne = 0
for b in range(0, data_input.size(0), bs):
output = model.forward(data_input.narrow(0, b, bs))
wta_prediction = output.data.max(1)[1].view(-1)
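The tail of nb_errors is also outside the hunk. A hypothetical
completion, consistent with the visible code: compare the
winner-take-all predictions against the targets of the current batch
and count the mismatches (the inner loop continues the for-b loop
above; indentation shown accordingly):

        for k in range(0, bs):
            if wta_prediction[k] != data_target[b + k]:
                ne = ne + 1

    return ne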
log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
for problem_number in range(1, 24):
- train_input, train_target = generate_set(problem_number, args.nb_train_samples)
- test_input, test_target = generate_set(problem_number, args.nb_test_samples)
+ train_input, train_target = generate_set(problem_number,
+ args.nb_train_batches * args.batch_size)
+ test_input, test_target = generate_set(problem_number,
+ args.nb_test_batches * args.batch_size)
model = AfrozeShallowNet()
if torch.cuda.is_available():