+ nb_parameters = 0
+ for p in model.parameters():
+ nb_parameters += p.numel()
+ log_string('nb_parameters {:d}'.format(nb_parameters))
+
+ model_filename = 'model_' + str(problem_number) + '.param'
+
+ try:
+ model.load_state_dict(torch.load(model_filename))
+ log_string('loaded_model ' + model_filename)
+ except:
+ log_string('training_model')
+ train_model(model, train_input, train_target)
+ torch.save(model.state_dict(), model_filename)
+ log_string('saved_model ' + model_filename)