parser.add_argument("--learning_rate", type=float, default=5e-4)
-parser.add_argument("--lambda_H", type=float, default=0.0)
+parser.add_argument("--reboot", action="store_true", default=False)
parser.add_argument("--schedule_free", action="store_true", default=False)
nb_train_samples, acc_train_loss = 0, 0.0
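+# The c-quizzes seen during training are the model's own bags plus the common
+# pool accumulated across reboots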
full_input, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier
+ args.nb_train_samples,
+ model.train_c_quiz_bags + common_c_quiz_bags,
+ args.c_quiz_multiplier,
)
src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
######################################################################
-def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
- # Save generated quizzes with the successive generation steps
-
- recorder = []
-
- c_quizzes = quiz_machine.generate_c_quizzes(
- 64,
- model_for_generation=model,
- procedure=c_quizzes_procedure,
- recorder=recorder,
- )
-
- # This is nb_quizzes x nb_models
-
- l = [
- quiz_machine.models_logprobas(
- model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- + quiz_machine.models_logprobas(
- model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
- for model in models
- ]
-
- seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
- probas = seq_logprobas.exp()
-
- comments = []
-
- for l in seq_logprobas:
- comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
-
- ##
-
- c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1)
- predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1)
- nb_steps = c_quizzes.size(1)
- c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1))
- predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1))
-
- # We have comments only for the final quiz, not the successive
- # steps, so we have to add nb_steps-1 empty comments
-
- steps_comments = []
- for c in comments:
- steps_comments += [""] * (nb_steps - 1) + [c]
-
- filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png"
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=c_quizzes,
- predicted_parts=predicted_parts,
- comments=steps_comments,
- nrow=nb_steps * 2, # two quiz per row
- )
-
- log_string(f"wrote {filename}")
-
-
-######################################################################
-
-
def model_proba_solutions(model, quizzes):
l = (
quiz_machine.models_logprobas(
return l.exp()
+######################################################################
+
+
def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
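+    # Each model must contribute nb_for_train + nb_for_test validated c-quizzes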
nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
######################################################################
-models = []
-
-
-def compute_causal_attzero(t_q, t_k):
- return t_q < t_k
+def create_models():
+ models = []
-if args.schedule_free:
- import schedulefree
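+    # Causal masking: a True entry zeroes the attention, so a query at t_q
+    # cannot attend to a key at a later t_k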
+ def compute_causal_attzero(t_q, t_k):
+ return t_q < t_k
-for k in range(args.nb_gpts):
- log_string(f"creating model {k}")
+ if args.schedule_free:
+ import schedulefree
+
+ for k in range(args.nb_gpts):
+ log_string(f"creating model {k}")
+
+ model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ compute_attzero=compute_causal_attzero,
+ dropout=args.dropout,
+ ).to(main_device)
+
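+        # Center the activations and divide by their std, with the divisor
+        # capped at std_max; used below to rescale the readout logits when
+        # --logit_std_max is set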
+ class UpperBoundStd(nn.Module):
+ def __init__(self, std_max=1.0):
+ super().__init__()
+ self.std_max = std_max
+
+ def forward(self, x):
+ std = x.std(dim=-1, keepdim=True)
+ y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
+ return y
+
+ if args.logit_std_max > 0:
+ model.readout.f = nn.Sequential(
+ model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
+ )
- model = mygpt.MyGPT(
- vocabulary_size=vocabulary_size,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- compute_attzero=compute_causal_attzero,
- dropout=args.dropout,
- ).to(main_device)
+ model.id = k
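+        # per-model bags of validated c-quizzes used for training and testing,
+        # on top of the common bags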
+ model.train_c_quiz_bags = []
+ model.test_c_quiz_bags = []
- class UpperBoundStd(nn.Module):
- def __init__(self, std_max=1.0):
- super().__init__()
- self.std_max = std_max
+ if args.schedule_free:
+ model.optimizer = schedulefree.AdamWScheduleFree(
+ model.parameters(), lr=args.learning_rate
+ )
+ else:
+ model.optimizer = torch.optim.Adam(
+ model.parameters(), lr=args.learning_rate
+ )
- def forward(self, x):
- std = x.std(dim=-1, keepdim=True)
- y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
- return y
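+        # test_accuracy tracks the current weights; gen_* snapshot the weights
+        # used for c-quiz generation, refreshed whenever test_accuracy reaches
+        # args.accuracy_to_make_c_quizzes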
+ model.test_accuracy = 0.0
+ model.gen_test_accuracy = 0.0
+ model.gen_state_dict = copy.deepcopy(model.state_dict())
+ models.append(model)
- if args.logit_std_max > 0:
- model.readout.f = nn.Sequential(
- model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
- )
+ return models
- model.id = k
- model.train_c_quiz_bags = []
- model.test_c_quiz_bags = []
- if args.schedule_free:
- model.optimizer = schedulefree.AdamWScheduleFree(
- model.parameters(), lr=args.learning_rate
- )
- else:
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
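+# c-quizzes promoted to the common culture at reboot time; every model trains
+# on them in addition to its own bags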
+common_c_quiz_bags = []
- model.test_accuracy = 0.0
- model.best_test_accuracy = 0.0
- model.best_dict = copy.deepcopy(model.state_dict())
- models.append(model)
+models = create_models()
######################################################################
model.load_state_dict(d["state_dict"])
model.optimizer.load_state_dict(d["optimizer_state_dict"])
model.test_accuracy = d["test_accuracy"]
- model.best_test_accuracy = d["best_test_accuracy"]
- model.best_dict = d["best_dict"]
+ model.gen_test_accuracy = d["gen_test_accuracy"]
+ model.gen_state_dict = d["gen_state_dict"]
model.train_c_quiz_bags = d["train_c_quiz_bags"]
model.test_c_quiz_bags = d["test_c_quiz_bags"]
log_string(f"successfully loaded {filename}")
######################################################################
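+# by default, target one new c-quiz for every 50 train / test samples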
if args.nb_new_c_quizzes_for_train is None:
- args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250
+ args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
if args.nb_new_c_quizzes_for_test is None:
- args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250
+ args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
log_string(
f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
- cta = " ".join([f"{float(m.best_test_accuracy):.04f}" for m in models])
- log_string(f"current_best_test_accuracies {cta}")
+ cta = " ".join([f"{float(m.gen_test_accuracy):.04f}" for m in models])
+ log_string(f"current_gen_test_accuracies {cta}")
##################################################
for model in models:
if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
log_string(
- f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}"
+ f"storing_gen model {model.id} accuracy {model.gen_test_accuracy} -> {model.test_accuracy}"
)
- model.best_dict = copy.deepcopy(model.state_dict())
- model.best_test_accuracy = model.test_accuracy
+ model.gen_state_dict = copy.deepcopy(model.state_dict())
+ model.gen_test_accuracy = model.test_accuracy
# reset the training-time counter until c-quiz generation has started
if total_time_generating_c_quizzes == 0:
total_time_training_models = 0
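+# Generate new c-quizzes only when all the generation snapshots have reached
+# the target accuracy and at least as much time has gone into training as into
+# generating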
if (
- min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
+ min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
and total_time_training_models >= total_time_generating_c_quizzes
):
+ ######################################################################
+    # Re-initialize if there are enough culture quizzes
+
+ if args.reboot:
+ nb_c_quizzes_per_model = [
+ sum([x.size(0) for x in model.train_c_quiz_bags]) for model in models
+ ]
+
+ m = max(nb_c_quizzes_per_model)
+
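+        # if the largest bag amounts to a full training set, promote it to the
+        # common culture and restart with fresh models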
+ if m >= args.nb_train_samples:
+ model = models[nb_c_quizzes_per_model.index(m)]
+ common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
+ nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
+ log_string(
+ f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
+ )
+
+ models = create_models()
+ total_time_generating_c_quizzes = 0
+ total_time_training_models = 0
+
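+# Switch every model to its generation snapshot to produce c-quizzes, keeping
+# the current weights aside to restore afterward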
for model in models:
model.current_dict = copy.deepcopy(model.state_dict())
- model.load_state_dict(model.best_dict)
+ model.load_state_dict(model.gen_state_dict)
start_time = time.perf_counter()
+
record_new_c_quizzes(
models,
quiz_machine,
args.nb_new_c_quizzes_for_train,
args.nb_new_c_quizzes_for_test,
)
+
total_time_generating_c_quizzes += time.perf_counter() - start_time
- # Force one epoch of training
for model in models:
model.load_state_dict(model.current_dict)
total_time_training_models += time.perf_counter() - start_time
- for model in weakest_models:
- save_additional_results(n_epoch, model, models, c_quizzes_procedure)
-
# Save the models to disk
for model in models:
"state_dict": model.state_dict(),
"optimizer_state_dict": model.optimizer.state_dict(),
"test_accuracy": model.test_accuracy,
- "best_test_accuracy": model.best_test_accuracy,
- "best_dict": model.best_dict,
+ "gen_test_accuracy": model.gen_test_accuracy,
+ "gen_state_dict": model.gen_state_dict,
"train_c_quiz_bags": model.train_c_quiz_bags,
"test_c_quiz_bags": model.test_c_quiz_bags,
},