parser.add_argument(
"--task",
type=str,
- default="twotargets",
+ default="world",
help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp, greed",
)
########################################
-parser.add_argument("--nb_epochs", type=int, default=50)
+parser.add_argument("--nb_epochs", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--learning_rate", type=float, default=1e-4)
-parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
+parser.add_argument("--learning_rate_schedule", type=str, default=None)
########################################
try:
os.mkdir(args.result_dir)
except FileExistsError:
- if not args.resume:
- print(f"result directory {args.result_dir} already exists")
- exit(1)
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
u = n_epoch / args.nb_epochs * math.pi
learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
else:
- u = {
- int(k): float(v)
- for k, v in [
- tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
- ]
- }
+ if args.learning_rate_schedule is not None:
+ u = {
+ int(k): float(v)
+ for k, v in [
+ tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+ ]
+ }
+ else:
+ u = {}
learning_rate_schedule = {}
learning_rate = args.learning_rate
######################################################################
-accuracy_to_make_quizzes = 0.95
+accuracy_to_make_quizzes = 0.975
-for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+for n_epoch in range(args.nb_epochs):
learning_rate = learning_rate_schedule[n_epoch]
for m in models:
one_epoch(m, task, learning_rate)
test_accuracy = run_tests(m, task, deterministic_synthesis=False)
- if test_accuracy >= accuracy_to_make_quizzes:
- other_models = models.copy()
- other_models.remove(model)
- create_quizzes(other_models, task)
+ if test_accuracy >= accuracy_to_make_quizzes:
+ other_models = models.copy()
+ other_models.remove(m)
+ create_quizzes(m, other_models, task)
# --------------------------------------------