parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+parser.add_argument("--nb_gpts", type=int, default=5)
+
parser.add_argument("--check", action="store_true", default=False)
######################################################################
######################################################################
-if args.test:
- args.nb_train_samples = 1000
- args.nb_test_samples = 25
+if args.check:
+ args.nb_train_samples = 500
+ args.nb_test_samples = 100
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
task.save_image(
new_quizzes[:96],
args.result_dir,
- f"world_new_{n_epoch:04d}_{model.id:02d}.png",
+ f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
log_string,
)
models = []
-for k in range(5):
+for k in range(args.nb_gpts):
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
nb_new_quizzes_for_train = 1000
nb_new_quizzes_for_test = 100
-if args.test:
+if args.check:
accuracy_to_make_quizzes = 0.0
nb_new_quizzes_for_train = 10
nb_new_quizzes_for_test = 10
self.save_image(
result[:96],
result_dir,
- f"world_result_{n_epoch:04d}_{model.id:02d}.png",
+ f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
logger,
)
device=self.device,
)
- nb_correct += (
- (
- (new_quizzes == result).long()
- * (inverted_quizzes, inverted_result).long()
- )
- .min(dim=-1)
- .values
- )
+ nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
+ inverted_quizzes == inverted_result
+ ).long().min(dim=-1).values
return new_quizzes, nb_correct