# grid_gray = 192
# thickness = 0
- # background_gray = 255
- # dots = True
+ # background_gray = 240
+ # dots = False
named_colors = [
("white", [background_gray, background_gray, background_gray]),
print(t.__name__)
quizzes = self.generate_w_quizzes_(nb, tasks=[t])
self.save_quizzes_as_image(
- result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow
+ result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True
)
nb, nrow = 64, 4
# nb, nrow = 8, 2
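+ # One-off visualization: sample cached c-quizzes and freshly generated
+ # w-quizzes, write them as PNG grids under /tmp, then exit.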
+ nb_rows = 13
+
+ c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"]
+ c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]]
+
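+ # Lay out the sampled c-quizzes as an image grid, nrow quizzes per row.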
+ grids.save_quizzes_as_image(
+ "/tmp",
+ "c_quizzes.png",
+ c_quizzes,
+ # delta=True,
+ nrow=nrow,
+ margin=10,
+ # grids=False
+ # comments=[f"c_quiz #{k}" for k in range(c_quizzes.size(0))],
+ )
+
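+ # Generate the same number of w-quizzes, restricted to four of the tasks.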
+ w_quizzes = grids.generate_w_quizzes_(
+ nrow * nb_rows,
+ tasks=[
+ grids.task_replace_color,
+ grids.task_translate,
+ grids.task_grow,
+ grids.task_frame,
+ ],
+ )
+
+ grids.save_quizzes_as_image(
+ "/tmp",
+ "w_quizzes.png",
+ w_quizzes,
+ # delta=True,
+ nrow=nrow,
+ margin=10,
+ # grids=False
+ # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+ )
+
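+ # Visualization only: stop before any training happens.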
+ exit(0)
# for t in grids.all_tasks:
parser.add_argument("--nb_train_samples", type=int, default=50000)
-parser.add_argument("--nb_test_samples", type=int, default=2500)
+parser.add_argument("--nb_test_samples", type=int, default=10000)
parser.add_argument("--nb_c_quizzes", type=int, default=5000)
# Compute the test accuracy
quizzes = generate_quiz_set(
- args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
+ args.nb_test_samples,
+ c_quizzes=None,
+ c_quiz_multiplier=args.c_quiz_multiplier,
+ # args.nb_test_samples, test_c_quizzes, args.c_quiz_multiplier
)
imt_set = samples_for_prediction_imt(quizzes.to(local_device))
result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
nb_correct += (nb_mistakes == 0).long()
nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
- # print("\n\n", nb_correct, nb_wrong)
-
return nb_correct, nb_wrong
"state_dict": model.state_dict(),
"optimizer_state_dict": model.optimizer.state_dict(),
"test_accuracy": model.test_accuracy,
+ "nb_epochs": model.nb_epochs,
},
os.path.join(args.result_dir, filename),
)
log_string(f"wrote {filename}")
+######################################################################
+
+
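+# Build model number i: construct the attention auto-encoder variant selected
+# by --model_type and attach its optimizer and bookkeeping attributes.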
+def new_model(i):
+ if args.model_type == "standard":
+ model_constructor = attae.AttentionAE
+ elif args.model_type == "functional":
+ model_constructor = attae.FunctionalAttentionAE
+ else:
+ raise ValueError(f"Unknown model type {args.model_type}")
+
+ model = model_constructor(
+ vocabulary_size=vocabulary_size * 2,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ dropout=args.dropout,
+ )
+
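+ # Per-model bookkeeping; test_accuracy and nb_epochs are saved in the checkpoint.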
+ model.id = i
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ model.test_accuracy = 0.0
+ model.nb_epochs = 0
+
+ return model
+
+
######################################################################
problem = grids.Grids(
models = []
-if args.model_type == "standard":
- model_constructor = attae.AttentionAE
-elif args.model_type == "functional":
- model_constructor = attae.FunctionalAttentionAE
-else:
- raise ValueError(f"Unknown model type {args.model_type}")
-
-
for i in range(args.nb_models):
- model = model_constructor(
- vocabulary_size=vocabulary_size * 2,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- dropout=args.dropout,
- )
-
+ model = new_model(i)
# model = torch.compile(model)
- model.id = i
- model.test_accuracy = 0.0
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
models.append(model)
######################################################################
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
model.test_accuracy = d["test_accuracy"]
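+ # Restore the epoch counter saved in the checkpoint.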
+ model.nb_epochs = d["nb_epochs"]
+
log_string(f"successfully loaded {filename}")
filename = "state.pth"
# The c quizzes used to estimate the test accuracy have to be
# solvable without hints
+
nb_correct, _ = evaluate_quizzes(
quizzes=train_c_quizzes,
models=models,