parser.add_argument("--learning_rate", type=float, default=5e-4)
+parser.add_argument("--lambda_H", type=float, default=0.0)
+
parser.add_argument("--schedule_free", action="store_true", default=False)
# ----------------------------------
targets = input
output = model(mygpt.BracketedSequence(input)).x
+
loss_per_token = F.cross_entropy(
output.transpose(1, 2), targets, reduction="none"
)
+
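+ # disabled experiment: mask out of the loss the tokens whose predictive entropy
+ # is below a small threshold, i.e. those the model already predicts with near certainty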
+ # warnings.warn("entropy masking", RuntimeWarning)
+ # l = output.transpose(1, 2).log_softmax(dim=1)
+ # H = -(l * l.exp()).sum(dim=1)
+ # M = (H >= -math.log(0.99) / H.size(1)).long()
+ # print(H, M)
+ # loss_per_token = loss_per_token * M
+
loss = (loss_per_token * mask_loss).mean() + model.loss
+
acc_train_loss += loss.item() * input.size(0)
loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
return bs
-if args.test == "func":
- test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
-
- L = test_input.size(1) // 4
- f_len = 50
-
- model = Thinker(
- vocabulary_size=vocabulary_size,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- f_len=f_len,
- dropout=args.dropout,
- ).to(main_device)
-
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
- for n_epoch in range(args.nb_epochs):
- model.train()
-
- train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-
- nb_train_samples, acc_train_loss = 0, 0.0
-
- for input in tqdm.tqdm(
- train_input.split(args.batch_size),
- dynamic_ncols=True,
- desc="training",
- total=train_input.size(0) // args.batch_size,
- ):
- input = input.to(main_device)
-
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.zero_grad()
-
- output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
- targets = input[:, 3 * L :]
- loss = F.cross_entropy(output.transpose(1, 2), targets)
- acc_train_loss += loss.item() * input.size(0)
-
- nb_train_samples += input.size(0)
-
- loss.backward()
-
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.step()
-
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
- log_string(f"train_perplexity {n_epoch} model thinker {train_perplexity}")
-
- with torch.autograd.no_grad():
- model.eval()
-
- nb_test_samples, acc_test_loss = 0, 0.0
-
- for input in tqdm.tqdm(
- test_input.split(args.batch_size),
- dynamic_ncols=True,
- desc="testing",
- total=test_input.size(0) // args.batch_size,
- ):
- input = input.to(main_device)
-
- output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
- targets = input[:, 3 * L :]
- loss = F.cross_entropy(output.transpose(1, 2), targets)
- acc_test_loss += loss.item() * input.size(0)
-
- nb_test_samples += input.size(0)
-
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
-
- log_string(f"test_perplexity {n_epoch} model thinker {test_perplexity}")
-
- input = test_input[:128].clone().to(main_device)
-
- output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
- dist = torch.distributions.categorical.Categorical(logits=output)
- input[:, 3 * L + 1 :] = dist.sample()[:, 1:]
-
- problem.save_quizzes_as_image(
- args.result_dir,
- f"thinker_prediction_{n_epoch:04d}.png",
- quizzes=input,
- # predicted_parts=predicted_parts,
- # correct_parts=correct_parts,
- )
-
-
######################################################################
models = []
model.test_accuracy = 0.0
model.best_test_accuracy = 0.0
-
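+ # snapshot of the weights, presumably refreshed whenever best_test_accuracy improves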
+ model.best_dict = copy.deepcopy(model.state_dict())
models.append(model)
######################################################################
exit(0)
######################################################################
-######################################################################
-if args.test == "reject":
- record = []
-
- c_quizzes_procedure = [
- (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
- (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
- (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
- (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
- (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
- (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold),
- (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
- ]
- while sum([x.size(0) for x in record]) < 64:
+def save_generated_c_quizzes(model, filename, nb=64):
+ # accumulate c-quizzes generated by randomly drawn models until at least nb are kept
+ record = []
+ while sum([x.size(0) for x in record]) < nb:
model = models[torch.randint(len(models), (1,)).item()]
c_quizzes = quiz_machine.generate_c_quizzes(
64,
print("NB_KEPT", sum([x.size(0) for x in record]))
- filename = f"sampling_with_rejection.png"
-
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
filename,
log_string(f"wrote {filename}")
+
+######################################################################
+
+if args.test == "entropy":
+ model = models[0]
+ model.to(main_device)
+
+ log_string("starting testing entropy maximization")
+
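+ # build a pool of c-quizzes generated by the model itself to train on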
+ train_input = quiz_machine.generate_c_quizzes(
+ 1000,
+ model_for_generation=model,
+ procedure=c_quizzes_procedure,
+ )
+
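+ # a few epochs of gradient steps whose only objective is to raise the per-token entropy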
+ for n_epoch in range(10):
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for input in train_input.split(args.batch_size):
+ input = input.to(main_device)
+ output = model(mygpt.BracketedSequence(input)).x
+ # per-token entropy of the predictions, with the log-softmax taken over the
+ # vocabulary dimension; the loss is its negative so that a gradient step
+ # increases the entropy
+ log_p = output.log_softmax(dim=-1)
+ entropy = -(log_p.exp() * log_p).sum(dim=-1).mean()
+ loss = -entropy
+
+ acc_train_loss += entropy.item() * input.size(0)
+ nb_train_samples += input.size(0)
+
+ model.optimizer.zero_grad()
+ loss.backward()
+ model.optimizer.step()
+
+ log_string(
+ f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
+ )
+
exit(0)
######################################################################
##################################################
# Select, improve, and eval the worst model(s)
- if total_time_training_models < total_time_generating_c_quizzes:
+ if total_time_training_models <= total_time_generating_c_quizzes:
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
total_time_training_models += time.perf_counter() - start_time
+ for model in weakest_models:
+ save_additional_results(n_epoch, model, models, c_quizzes_procedure)
+
# Save the models to disk
for model in models:
)
log_string(f"wrote {filename}")
- for model in weakest_models:
- save_additional_results(n_epoch, model, models, c_quizzes_procedure)
-
######################################################################
if args.log_command is not None: