Added --nb_validation_samples and --validation_error_threshold to terminate learning...
author Francois Fleuret <francois@fleuret.org>
Mon, 19 Jun 2017 06:33:11 +0000 (08:33 +0200)
committer Francois Fleuret <francois@fleuret.org>
Mon, 19 Jun 2017 06:33:11 +0000 (08:33 +0200)
cnn-svrt.py

index e5ecf76..f3d350e 100755 (executable)
@@ -56,6 +56,13 @@ parser.add_argument('--nb_train_samples',
 parser.add_argument('--nb_test_samples',
                     type = int, default = 10000)
 
+parser.add_argument('--nb_validation_samples',
+                    type = int, default = 10000)
+
+parser.add_argument('--validation_error_threshold',
+                    type = float, default = 0.0,
+                    help = 'Early training termination criterion')
+
 parser.add_argument('--nb_epochs',
                     type = int, default = 50)
 
@@ -194,7 +201,22 @@ class AfrozeDeepNet(nn.Module):
 
 ######################################################################
 
-def train_model(model, train_set):
+def nb_errors(model, data_set):
+    ne = 0
+    for b in range(0, data_set.nb_batches):
+        input, target = data_set.get_batch(b)
+        output = model.forward(Variable(input))
+        wta_prediction = output.data.max(1)[1].view(-1)
+
+        for i in range(0, data_set.batch_size):
+            if wta_prediction[i] != target[i]:
+                ne = ne + 1
+
+    return ne
+
+######################################################################
+
+def train_model(model, train_set, validation_set):
     batch_size = args.batch_size
     criterion = nn.CrossEntropyLoss()
 
@@ -216,25 +238,24 @@ def train_model(model, train_set):
             loss.backward()
             optimizer.step()
         dt = (time.time() - start_t) / (e + 1)
+
         log_string('train_loss {:d} {:f}'.format(e + 1, acc_loss),
                    ' [ETA ' + time.ctime(time.time() + dt * (args.nb_epochs - e)) + ']')
 
-    return model
+        if validation_set is not None:
+            nb_validation_errors = nb_errors(model, validation_set)
 
-######################################################################
+            log_string('validation_error {:.02f}% {:d} {:d}'.format(
+                100 * nb_validation_errors / validation_set.nb_samples,
+                nb_validation_errors,
+                validation_set.nb_samples)
+            )
 
-def nb_errors(model, data_set):
-    ne = 0
-    for b in range(0, data_set.nb_batches):
-        input, target = data_set.get_batch(b)
-        output = model.forward(Variable(input))
-        wta_prediction = output.data.max(1)[1].view(-1)
+            if nb_validation_errors / validation_set.nb_samples <= args.validation_error_threshold:
+                log_string('below validation_error_threshold')
+                break
 
-        for i in range(0, data_set.batch_size):
-            if wta_prediction[i] != target[i]:
-                ne = ne + 1
-
-    return ne
+    return model
 
 ######################################################################
 
@@ -329,7 +350,15 @@ for problem_number in map(int, args.problems.split(',')):
             train_set.nb_samples / (time.time() - t))
         )
 
-        train_model(model, train_set)
+        if args.validation_error_threshold > 0.0:
+            validation_set = VignetteSet(problem_number,
+                                         args.nb_validation_samples, args.batch_size,
+                                         cuda = torch.cuda.is_available(),
+                                         logger = vignette_logger())
+        else:
+            validation_set = None
+
+        train_model(model, train_set, validation_set)
         torch.save(model.state_dict(), model_filename)
         log_string('saved_model ' + model_filename)
 
@@ -353,10 +382,6 @@ for problem_number in map(int, args.problems.split(',')):
                                args.nb_test_samples, args.batch_size,
                                cuda = torch.cuda.is_available())
 
-        log_string('data_generation {:0.2f} samples / s'.format(
-            test_set.nb_samples / (time.time() - t))
-        )
-
         nb_test_errors = nb_errors(model, test_set)
 
         log_string('test_error {:d} {:.02f}% {:d} {:d}'.format(