parser.add_argument("--nb_train_samples", type=int, default=50000)
-parser.add_argument("--nb_test_samples", type=int, default=1000)
+parser.add_argument("--nb_test_samples", type=int, default=2500)
parser.add_argument("--nb_c_quizzes", type=int, default=5000)
######################################################################
+def optimizer_to(optim, device):
+ """Move the optimizer optim to the device"""
+ for param in optim.state.values():
+ # Not sure there are any global tensors in the state dict
+ if isinstance(param, torch.Tensor):
+ param.data = param.data.to(device)
+ if param._grad is not None:
+ param._grad.data = param._grad.data.to(device)
+ elif isinstance(param, dict):
+ for subparam in param.values():
+ if isinstance(subparam, torch.Tensor):
+ subparam.data = subparam.data.to(device)
+ if subparam._grad is not None:
+ subparam._grad.data = subparam._grad.data.to(device)
+
+
+######################################################################
+
+
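For context on this hunk: calling model.to(device) moves the module's parameters and buffers but not the optimizer's per-parameter state (for Adam, the moment tensors), which is why optimizer_to mutates .data and ._grad in place. A minimal usage sketch; the toy model and optimizer below are illustrative and not part of this patch, only optimizer_to itself comes from it:

    import torch
    from torch import nn

    model = nn.Linear(8, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # One step on the CPU so that Adam allocates its state tensors
    model(torch.randn(2, 8)).sum().backward()
    optimizer.step()

    if torch.cuda.is_available():
        device = torch.device("cuda")
        model.to(device)                 # moves parameters/buffers, not the optimizer state
        optimizer_to(optimizer, device)  # moves exp_avg / exp_avg_sq etc. to the same device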
def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
if c_quizzes is None:
quizzes = problem.generate_w_quizzes(nb_samples)
######################################################################
-def optimizer_to(optim, device):
- """Move the optimizer optim to the device"""
- for param in optim.state.values():
- # Not sure there are any global tensors in the state dict
- if isinstance(param, torch.Tensor):
- param.data = param.data.to(device)
- if param._grad is not None:
- param._grad.data = param._grad.data.to(device)
- elif isinstance(param, dict):
- for subparam in param.values():
- if isinstance(subparam, torch.Tensor):
- subparam.data = subparam.data.to(device)
- if subparam._grad is not None:
- subparam._grad.data = subparam._grad.data.to(device)
-
-
-######################################################################
-
-
def add_hints_imt(imt_set):
"""Set every component of the mask to zero with probability
args.proba_hint, and for each component set to zero, copy the
######################################################################
-def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
- one_epoch(model, n_epoch, c_quizzes, train=True, local_device=local_device)
+def one_complete_epoch(
+ model, n_epoch, train_c_quizzes, test_c_quizzes, local_device=main_device
+):
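+ # One training pass over the train pool of c-quizzes, then one
+ # evaluation pass over the held-out test pool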
+ one_epoch(model, n_epoch, train_c_quizzes, train=True, local_device=local_device)
- one_epoch(model, n_epoch, c_quizzes, train=False, local_device=local_device)
+ one_epoch(model, n_epoch, test_c_quizzes, train=False, local_device=local_device)
# Compute the test accuracy
)
-def evaluate_quizzes(quizzes, models, local_device):
+def evaluate_quizzes(quizzes, models, with_hints, local_device):
nb_correct, nb_wrong = 0, 0
for model in models:
model=model,
input=quizzes,
with_noise=False,
- with_hints=True,
+ with_hints=with_hints,
local_device=local_device,
)
nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted)
nb_correct += (nb_mistakes == 0).long()
nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
- to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
- nb_wrong >= args.nb_have_to_be_wrong
- )
-
# print("\n\n", nb_correct, nb_wrong)
- return to_keep, nb_correct, nb_wrong
+ return nb_correct, nb_wrong
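Since evaluate_quizzes no longer computes the selection mask itself, each call site now combines the two per-quiz counts with its own thresholds. A toy sketch with made-up counts for 3 models and 4 quizzes; the thresholds of 1 are illustrative, not the defaults:

    import torch

    nb_correct = torch.tensor([3, 2, 0, 1])  # models solving each quiz
    nb_wrong = torch.tensor([0, 1, 3, 2])    # models failing badly (>= nb_mistakes_to_be_wrong)

    to_keep = (nb_correct >= 1) & (nb_wrong >= 1)
    # tensor([False,  True, False,  True]): quizzes some models solve and others miss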
######################################################################
# Select the ones that are solved properly by some models and
# not understood by others
- to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+ nb_correct, nb_wrong = evaluate_quizzes(
quizzes=c_quizzes,
models=models,
+ with_hints=True,
local_device=local_device,
)
+ to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+ nb_wrong >= args.nb_have_to_be_wrong
+ )
+
nb_validated += to_keep.long().sum().item()
record.append(c_quizzes[to_keep])
for args in arguments:
# To get a different sequence between threads
- # log_string(f"dummy_rand {torch.rand(1)}")
- torch.rand(1)
+ log_string(f"dummy_rand {torch.rand(1)}")
+ # torch.rand(1)
t = threading.Thread(target=threadable_fun, daemon=True, args=args)
threads.append(t)
t.start()
def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
c_quizzes = c_quizzes.to(local_device)
- to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+ nb_correct, nb_wrong = evaluate_quizzes(
quizzes=c_quizzes,
models=models,
+ with_hints=False,
local_device=local_device,
)
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
model.test_accuracy = d["test_accuracy"]
- # model.gen_test_accuracy = d["gen_test_accuracy"]
- # model.gen_state_dict = d["gen_state_dict"]
- # model.train_c_quiz_bags = d["train_c_quiz_bags"]
- # model.test_c_quiz_bags = d["test_c_quiz_bags"]
log_string(f"successfully loaded {filename}")
filename = "state.pth"
log_string(f"successfully loaded {filename}")
current_epoch = state["current_epoch"]
- c_quizzes = state["c_quizzes"]
+ train_c_quizzes = state["train_c_quizzes"]
+ test_c_quizzes = state["test_c_quizzes"]
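An old state.pth will not contain the new keys, so the load above raises a KeyError on checkpoints written before this change. If that matters, a tolerant variant could look like the sketch below; it is illustrative only, and it assumes the previously saved pool should seed the new training pool (both old key spellings that appear in this patch are tried):

    train_c_quizzes = state.get(
        "train_c_quizzes", state.get("main_c_quizzes", state.get("c_quizzes"))
    )
    test_c_quizzes = state.get("test_c_quizzes")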
######################################################################
######################################################################
-main_c_quizzes = None
+train_c_quizzes, test_c_quizzes = None, None
######################################################################
state = {
"current_epoch": n_epoch,
- "main_c_quizzes": main_c_quizzes,
+ "train_c_quizzes": train_c_quizzes,
+ "test_c_quizzes": test_c_quizzes,
}
filename = "state.pth"
lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
- if main_c_quizzes is None:
+ if train_c_quizzes is None:
save_models(models, "naive")
nb_gpus = len(gpus)
log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
- main_c_quizzes = (
+ train_c_quizzes = (
new_c_quizzes
- if main_c_quizzes is None
- else torch.cat([main_c_quizzes, new_c_quizzes])
+ if train_c_quizzes is None
+ else torch.cat([train_c_quizzes, new_c_quizzes])
)
- main_c_quizzes = main_c_quizzes[-args.nb_train_samples :]
+ train_c_quizzes = train_c_quizzes[-args.nb_train_samples :]
+
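+ # Without hints this time: the test pool is the subset of the train pool
+ # that at least nb_have_to_be_correct models already solve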
+ nb_correct, _ = evaluate_quizzes(
+ quizzes=train_c_quizzes,
+ models=models,
+ with_hints=False,
+ local_device=local_device,
+ )
+
+ test_c_quizzes = train_c_quizzes[nb_correct >= args.nb_have_to_be_correct]
for model in models:
model.test_accuracy = 0
- if main_c_quizzes is None:
+ if train_c_quizzes is None:
log_string("no_c_quiz")
else:
- log_string(f"nb_c_quizzes {main_c_quizzes.size(0)}")
+ log_string(f"nb_c_quizzes {train_c_quizzes.size(0)}")
# --------------------------------------------------------------------
multithread_execution(
one_complete_epoch,
[
- (model, n_epoch, main_c_quizzes, gpu)
+ (model, n_epoch, train_c_quizzes, test_c_quizzes, gpu)
for model, gpu in zip(weakest_models, gpus)
],
)
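For reference, each tuple passed above binds positionally to the new one_complete_epoch signature; the equivalent direct call for a single model would be the following, shown only to make that mapping explicit (the threaded dispatch above is what the code actually does):

    one_complete_epoch(
        model=weakest_models[0],
        n_epoch=n_epoch,
        train_c_quizzes=train_c_quizzes,
        test_c_quizzes=test_c_quizzes,
        local_device=gpus[0],
    )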