# SVRT
-from vignette_set import VignetteSet, CompressedVignetteSet
+import vignette_set
######################################################################
print('The number of samples must be a multiple of the batch size.')
raise
+if args.compress_vignettes:
+ VignetteSet = vignette_set.CompressedVignetteSet
+else:
+ VignetteSet = vignette_set.VignetteSet
+
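This refactor only works if both set classes expose the same constructor, so that either one can be bound to the single name `VignetteSet` up front. The sketch below spells out that assumed interface; the class bodies are placeholders, and the argument names and order are inferred from the calls further down, not taken from the real `vignette_set` module.

```python
# Hypothetical stand-ins for vignette_set.VignetteSet and
# vignette_set.CompressedVignetteSet: both accept the same arguments,
# which is what lets the script bind either one to a single name.
class VignetteSet:
    def __init__(self, problem_number, nb_samples, batch_size, cuda=False):
        self.problem_number = problem_number
        self.nb_samples = nb_samples
        self.batch_size = batch_size
        self.cuda = cuda

class CompressedVignetteSet(VignetteSet):
    # Same signature; only the internal storage would differ.
    pass

# Selecting the class once, as the patch does with args.compress_vignettes.
compress_vignettes = True
SelectedSet = CompressedVignetteSet if compress_vignettes else VignetteSet
train_set = SelectedSet(1, 1000, 100, cuda=False)
print(type(train_set).__name__, train_set.nb_samples)
```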
for problem_number in range(1, 24):
- log_string('**** problem ' + str(problem_number) + ' ****')
+ log_string('############### problem ' + str(problem_number) + ' ###############')
if args.deep_model:
model = AfrozeDeepNet()
else:
model = AfrozeShallowNet()
- if torch.cuda.is_available():
- model.cuda()
+ if torch.cuda.is_available(): model.cuda()
model_filename = model.name + '_' + \
str(problem_number) + '_' + \
    nb_parameters = 0
    for p in model.parameters(): nb_parameters += p.numel()
log_string('nb_parameters {:d}'.format(nb_parameters))
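As an aside, the same parameter count can be computed in a single expression with standard PyTorch; a self-contained sketch, with a toy model standing in for `AfrozeShallowNet` / `AfrozeDeepNet`:

```python
import torch.nn as nn

# Toy model standing in for the actual networks used above.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))

# Equivalent to the explicit accumulation loop.
nb_parameters = sum(p.numel() for p in model.parameters())
print('nb_parameters {:d}'.format(nb_parameters))
```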
+ ##################################################
+ # Tries to load the model
+
need_to_train = False
try:
model.load_state_dict(torch.load(model_filename))
except:
need_to_train = True
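The try/except above implements a load-or-train pattern on top of PyTorch's state-dict round trip. A minimal standalone sketch of that pattern follows; the file name and toy model are placeholders, and unlike the bare `except` above it only catches a missing file.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)                 # placeholder model
model_filename = 'example_model.param'  # placeholder file name

need_to_train = False
try:
    # Succeeds only if a compatible state dict was previously written
    # with torch.save(model.state_dict(), model_filename).
    model.load_state_dict(torch.load(model_filename))
except FileNotFoundError:
    need_to_train = True

if need_to_train:
    # ... train the model, then persist it for the next run ...
    torch.save(model.state_dict(), model_filename)
```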
+ ##################################################
+ # Train if necessary
+
if need_to_train:
log_string('training_model ' + model_filename)
t = time.time()
- if args.compress_vignettes:
- train_set = CompressedVignetteSet(problem_number,
- args.nb_train_samples, args.batch_size,
- cuda = torch.cuda.is_available())
- else:
- train_set = VignetteSet(problem_number,
- args.nb_train_samples, args.batch_size,
- cuda = torch.cuda.is_available())
+ train_set = VignetteSet(problem_number,
+ args.nb_train_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
        log_string('data_generation {:0.2f} samples / s'.format(
            train_set.nb_samples / (time.time() - t))
        )
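Both the training and test branches report data-generation throughput the same way. A tiny sketch of that measurement, with a placeholder generator instead of the real vignette code:

```python
import time

def generate_samples(nb_samples):
    # Placeholder for the actual vignette generation.
    return [i * i for i in range(nb_samples)]

t = time.time()
samples = generate_samples(100000)
print('data_generation {:0.2f} samples / s'.format(len(samples) / (time.time() - t)))
```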
+ ##################################################
+ # Test if necessary
+
if need_to_train or args.test_loaded_models:
t = time.time()
- if args.compress_vignettes:
- test_set = CompressedVignetteSet(problem_number,
- args.nb_test_samples, args.batch_size,
- cuda = torch.cuda.is_available())
- else:
- test_set = VignetteSet(problem_number,
- args.nb_test_samples, args.batch_size,
- cuda = torch.cuda.is_available())
+ test_set = VignetteSet(problem_number,
+ args.nb_test_samples, args.batch_size,
+ cuda = torch.cuda.is_available())
log_string('data_generation {:0.2f} samples / s'.format(
test_set.nb_samples / (time.time() - t))