Replace the numbers of samples by numbers of batches of samples.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 19:24:25 +0000 (21:24 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 19:24:25 +0000 (21:24 +0200)
cnn-svrt.py

index e7e4574..ab1b363 100755 (executable)
@@ -44,18 +44,22 @@ parser = argparse.ArgumentParser(
     formatter_class = argparse.ArgumentDefaultsHelpFormatter
 )
 
-parser.add_argument('--nb_train_samples',
-                    type = int, default = 100000,
+parser.add_argument('--nb_train_batches',
+                    type = int, default = 1000,
                     help = 'How many samples for train')
 
-parser.add_argument('--nb_test_samples',
-                    type = int, default = 10000,
+parser.add_argument('--nb_test_batches',
+                    type = int, default = 100,
                     help = 'How many samples for test')
 
 parser.add_argument('--nb_epochs',
                     type = int, default = 50,
                     help = 'How many training epochs')
 
+parser.add_argument('--batch_size',
+                    type = int, default = 100,
+                    help = 'Mini-batch size')
+
 parser.add_argument('--log_file',
                     type = str, default = 'cnn-svrt.log',
                     help = 'Log file name')
@@ -120,12 +124,13 @@ class AfrozeShallowNet(nn.Module):
         return x
 
 def train_model(model, train_input, train_target):
+    bs = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
     if torch.cuda.is_available():
         criterion.cuda()
 
-    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
+    optimizer = optim.SGD(model.parameters(), lr = 1e-2)
 
     for k in range(0, args.nb_epochs):
         acc_loss = 0.0
@@ -142,9 +147,10 @@ def train_model(model, train_input, train_target):
 
 ######################################################################
 
-def nb_errors(model, data_input, data_target, bs = 100):
-    ne = 0
+def nb_errors(model, data_input, data_target):
+    bs = args.batch_size
 
+    ne = 0
     for b in range(0, data_input.size(0), bs):
         output = model.forward(data_input.narrow(0, b, bs))
         wta_prediction = output.data.max(1)[1].view(-1)
@@ -161,8 +167,10 @@ for arg in vars(args):
     log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
 
 for problem_number in range(1, 24):
-    train_input, train_target = generate_set(problem_number, args.nb_train_samples)
-    test_input, test_target = generate_set(problem_number, args.nb_test_samples)
+    train_input, train_target = generate_set(problem_number,
+                                             args.nb_train_batches * args.batch_size)
+    test_input, test_target = generate_set(problem_number,
+                                           args.nb_test_batches * args.batch_size)
     model = AfrozeShallowNet()
 
     if torch.cuda.is_available():