- train_input, train_target = generate_set(problem_number, args.nb_train_samples)
- test_input, test_target = generate_set(problem_number, args.nb_test_samples)
+ if args.compress_vignettes:
+ train_set = CompressedVignetteSet(problem_number, args.nb_train_batches, args.batch_size)
+ test_set = CompressedVignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+ else:
+ train_set = VignetteSet(problem_number, args.nb_train_batches, args.batch_size)
+ test_set = VignetteSet(problem_number, args.nb_test_batches, args.batch_size)
+
+ model = AfrozeShallowNet()