From 2fd030cd849fa7879211128c15d4a1fbf9d6e7f4 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 15 Jun 2017 15:15:13 +0200 Subject: [PATCH] Log the number of parameters. --- cnn-svrt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cnn-svrt.py b/cnn-svrt.py index 1d5e887..3550d85 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -122,6 +122,11 @@ class Net(nn.Module): def train_model(train_input, train_target): model, criterion = Net(), nn.CrossEntropyLoss() + nb_parameters = 0 + for p in model.parameters(): + nb_parameters += p.numel() + log_string('NB_PARAMETERS {:d}'.format(nb_parameters)) + if torch.cuda.is_available(): model.cuda() criterion.cuda() -- 2.20.1