Log the number of parameters.
[pysvrt.git] / cnn-svrt.py
index 1d5e887..3550d85 100755 (executable)
@@ -122,6 +122,11 @@ class Net(nn.Module):
 def train_model(train_input, train_target):
     model, criterion = Net(), nn.CrossEntropyLoss()
 
+    nb_parameters = 0
+    for p in model.parameters():
+        nb_parameters += p.numel()
+    log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
+
     if torch.cuda.is_available():
         model.cuda()
         criterion.cuda()