def train_model(train_input, train_target):
model, criterion = Net(), nn.CrossEntropyLoss()
+ nb_parameters = 0
+ for p in model.parameters():
+ nb_parameters += p.numel()
+ log_string('NB_PARAMETERS {:d}'.format(nb_parameters))
+
if torch.cuda.is_available():
model.cuda()
criterion.cuda()