From: Francois Fleuret Date: Thu, 15 Jun 2017 13:15:13 +0000 (+0200) Subject: Log the number of parameters. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=2fd030cd849fa7879211128c15d4a1fbf9d6e7f4;p=pysvrt.git Log the number of parameters. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index 1d5e887..3550d85 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -122,6 +122,11 @@ class Net(nn.Module): def train_model(train_input, train_target): model, criterion = Net(), nn.CrossEntropyLoss() + nb_parameters = 0 + for p in model.parameters(): + nb_parameters += p.numel() + log_string('NB_PARAMETERS {:d}'.format(nb_parameters)) + if torch.cuda.is_available(): model.cuda() criterion.cuda()