parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--inference_batch_size", type=int, default=None)
+
parser.add_argument("--nb_train_samples", type=int, default=None)
parser.add_argument("--nb_test_samples", type=int, default=None)
default_args = {
"model": "37M",
"batch_size": 25,
+ "inference_batch_size": 100,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
}
quiz_machine = quiz_machine.QuizMachine(
problem=problem,
- batch_size=args.physical_batch_size,
+ batch_size=args.inference_batch_size,
result_dir=args.result_dir,
logger=log_string,
device=main_device,
######################################################################
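+# Train model_gen to produce c_quizzes that model_pred1 answers with high
+# confidence (proba >= args.proba_understands) while model_pred2 does not
+# (proba <= args.proba_not_understands).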
+def train_complexifier(model_gen, model_pred1, model_pred2):
+ samples = []
+ perf = []
+
+ optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate)
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for n_epoch in range(args.nb_epochs):
+ for b in range(args.nb_train_samples // args.batch_size):
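+ # generate and filter c_quizzes until a full training batch has been collected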
+ while sum([x.size(0) for x in samples]) < args.batch_size:
+ c_quizzes = quiz_machine.generate_c_quizzes(
+ args.inference_batch_size,
+ model_for_generation=model_gen,
+ procedure=c_quizzes_procedure,
+ )
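+ # discard generated quizzes that the problem flags as trivial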
+ to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+ c_quizzes = c_quizzes[to_keep]
+ if c_quizzes.size(0) > 0:
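+ # sum the two predictors' log-probabilities over both orderings of each quiz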
+ seq_logproba = quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ c_quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 1),
+ ) + quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ c_quizzes,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 1),
+ )
+ probas = seq_logproba.exp()
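+ # keep quizzes that model_pred1 solves confidently and model_pred2 does not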
+ to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & (
+ probas[:, model_pred2.id] <= args.proba_not_understands
+ )
+ log_string(
+ f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}"
+ )
+ c_quizzes = c_quizzes[to_keep]
+ if c_quizzes.size(0):
+ samples.append(c_quizzes)
+
+ log_string(f"full batch {sum([x.size(0) for x in samples])}")
+
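+ # take one training batch and keep the surplus quizzes for the next iteration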
+ x = torch.cat(samples, dim=0)
+
+ input = x[: args.batch_size]
+ samples = [x[args.batch_size :]]
+
+ # -------------------
+
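+ # re-score the selected batch with both predictors to annotate the saved image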
+ seq_logproba = quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ input,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 1),
+ ) + quiz_machine.models_logprobas(
+ [model_pred1, model_pred2],
+ input,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 1),
+ )
+
+ comments = []
+
+ for l in seq_logproba:
+ comments.append(
+ f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}"
+ )
+
+ filename = f"batch_{n_epoch:04d}_{b:04d}.png"
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir, filename, input, comments=comments
+ )
+ log_string(f"wrote {filename}")
+
+ # ------------------------
+
+ input = input.to(main_device)
+
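+ # one cross-entropy training step on model_gen (zero_grad / step are gated on multiples of args.batch_size)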
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.zero_grad()
+
+ output = model_gen(mygpt.BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.step()
+
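+ # running train perplexity of model_gen, accumulated over all epochs so far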
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} model_gen {train_perplexity}")
+
+
+######################################################################
+
+
def train_autoencoder():
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
return model
-if args.autoencoder_dim > 0:
- ae = train_autoencoder()
- exit(0)
+# if args.autoencoder_dim > 0:
+# ae = train_autoencoder()
+# exit(0)
######################################################################
######################################################################
+# DIRTY TEST
+
+# train_complexifier(models[0], models[1], models[2])
+
+# exit(0)
+
+######################################################################
+
for n_epoch in range(current_epoch, args.nb_epochs):
state = {"current_epoch": n_epoch}
filename = "state.pth"