+if args.physical_batch_size is None:
+ args.physical_batch_size = args.batch_size
+else:
+ assert args.batch_size % args.physical_batch_size == 0
+
+assert args.nb_train_samples % args.batch_size == 0
+assert args.nb_test_samples % args.batch_size == 0
+
+if args.problem == "sky":
+ problem = sky.Sky(
+ height=args.sky_height,
+ width=args.sky_width,
+ nb_birds=args.sky_nb_birds,
+ nb_iterations=args.sky_nb_iterations,
+ speed=args.sky_speed,
+ max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ )
+ back_accuracy = False
+elif args.problem == "grids":
+ problem = grids.Grids(
+ max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_tasks,
+ )
+ back_accuracy = True
+else:
+ raise ValueError
+
+problem.save_some_examples(args.result_dir)
+
+quiz_machine = quiz_machine.QuizMachine(
+ problem=problem,
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ back_accuracy=back_accuracy,
+ batch_size=args.physical_batch_size,
+ result_dir=args.result_dir,
+ logger=log_string,
+ device=device,
+)