######################################################################
-class vanilla_attention(q, k, v):
+def vanilla_attention(q, k, v):
a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
a = a.softmax(dim=3)
y = torch.einsum("nhts,nhsd->nhtd", a, v)
-
- # y = flex_attention(q, k, v, score_mod=noop)
-
y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
-
return y
-vanilla_attention = torch.compille(vanilla_attention)
+vanilla_attention = torch.compile(vanilla_attention)
+
+# y = flex_attention(q, k, v, score_mod=noop)
class MHAttention(nn.Module):
def noop(score, b, h, q_idx, kv_idx):
return score
- y = vanilla_attention(q, k, v, score_mod=noop)
+ y = vanilla_attention(q, k, v)
# y = flex_attention(q, k, v, score_mod=noop)
y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
m.weight.fill_(1.0)
def forward(self, x):
- x = 2 * x[:, :, 0] + x[:, :, 1]
x = self.embedding(x)
x = self.positional_encoding(x)
x = self.trunk(x)
parser.add_argument("--nb_train_samples", type=int, default=25000)
-parser.add_argument("--nb_test_samples", type=int, default=1000)
+parser.add_argument("--nb_test_samples", type=int, default=10000)
parser.add_argument("--nb_train_alien_samples", type=int, default=0)
######################################################################
-for n_epoch in range(current_epoch, args.nb_epochs):
- start_time = time.perf_counter()
+def save_models(models, suffix=""):
+    """Write one checkpoint file per model into args.result_dir.
+
+    Each file is named ae_<id>.pth, or ae_<id>_<suffix>.pth when a
+    non-empty suffix is given, with the model id zero-padded to three
+    digits. The checkpoint stores the model's state dict, its
+    optimizer's state dict and its current test accuracy.
+    """
+    # Compare by value: `is not ""` tests object identity, which depends
+    # on string interning and triggers a SyntaxWarning on Python 3.8+.
+    if suffix != "":
+        suffix = "_" + suffix
+    for model in models:
+        filename = f"ae_{model.id:03d}{suffix}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+            },
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
+
+
+######################################################################
+
+for n_epoch in range(current_epoch, args.nb_epochs):
state = {
"current_epoch": n_epoch,
"c_quizzes": c_quizzes,
and time_train >= time_c_quizzes
):
if c_quizzes is None:
- for model in models:
- filename = f"ae_{model.id:03d}_naive.pth"
- torch.save(
- {
- "state_dict": model.state_dict(),
- "optimizer_state_dict": model.optimizer.state_dict(),
- "test_accuracy": model.test_accuracy,
- },
- os.path.join(args.result_dir, filename),
- )
- log_string(f"wrote {filename}")
-
- # --------------------------------------------------------------------
+ save_models(models, "naive")
last_n_epoch_c_quizzes = n_epoch
nb_gpus = len(gpus)
nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
- # --------------------------------------------------------------------
+ start_time = time.perf_counter()
c_quizzes, agreements = multithread_execution(
generate_ae_c_quizzes,
[(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
)
- # --------------------------------------------------------------------
-
- filename = f"culture_c_quiz_{n_epoch:04d}.png"
save_c_quizzes_with_scores(
- models, c_quizzes[:256], filename, solvable_only=False
+ models,
+ c_quizzes[:256],
+ f"culture_c_quiz_{n_epoch:04d}.png",
+ solvable_only=False,
)
- filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png"
save_c_quizzes_with_scores(
- models, c_quizzes[:256], filename, solvable_only=True
+ models,
+ c_quizzes[:256],
+ f"culture_c_quiz_{n_epoch:04d}_solvable.png",
+ solvable_only=True,
)
log_string(f"generated_c_quizzes {c_quizzes.size()=}")
time_train = 0
+
for model in models:
model.test_accuracy = 0
ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
weakest_models = ranked_models[: len(gpus)]
+ start_time = time.perf_counter()
+
multithread_execution(
one_ae_epoch,
[
# --------------------------------------------------------------------
- for model in models:
- filename = f"ae_{model.id:03d}.pth"
- torch.save(
- {
- "state_dict": model.state_dict(),
- "optimizer_state_dict": model.optimizer.state_dict(),
- "test_accuracy": model.test_accuracy,
- },
- os.path.join(args.result_dir, filename),
- )
- log_string(f"wrote {filename}")
-
- # --------------------------------------------------------------------
+ save_models(models)
duration = time.perf_counter() - start_time
str_duration = ""