Log the number of parameters.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 13:15:13 +0000 (15:15 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 15 Jun 2017 13:15:13 +0000 (15:15 +0200)
cnn-svrt.py

index 1d5e887..3550d85 100755 (executable)
@@ -122,6 +122,11 @@ class Net(nn.Module):
 def train_model(train_input, train_target):
     model, criterion = Net(), nn.CrossEntropyLoss()
 
+    nb_parameters = 0
+    for p in model.parameters():
+        nb_parameters += p.numel()
+    log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
+
     if torch.cuda.is_available():
         model.cuda()
         criterion.cuda()