self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
def trunk(nb, bottom=True):
- trunk_blocks = []
+ trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)]
la = [
QKVAttention(
nb_heads=nb_heads,
attention_dropout=dropout,
),
- VaswaniPositionalEncoding(len_max=1e5),
]
# if not bottom:
theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
+ # if self.hook_theta is not None:
+ # self.hook_theta(theta_A, theta_B)
+
hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
model.test_accuracy = nb_correct / nb_total
- for f, record in [("prediction", record_d), ("generative", record_nd)]:
+ for f, record in [("prediction", record_d), ("generation", record_nd)]:
filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
result, predicted_parts, correct_parts = (
torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
)
log_string(f"wrote {filename}")
+ # Prediction with functional perturbations
+
+ # input, mask_generate, mask_loss = next(
+ # ae_batches(
+ # quiz_machine,
+ # [
+ # (
+ # ("A", "f_A", "B", "f_B"),
+ # (0, 0, 0, 1),
+ # (0, 0, 1, 0),
+ # (0, 0, 0, 1),
+ # ),
+ # ],
+ # local_device,
+ # desc=None,
+ # )
+ # )
+ # targets = input.clone()
+ # p = torch.rand(4,model.f_tokens.size(1)).sort(dim=1).indices
+ # def change_theta(theta_A, theta_B):
+ # theta
+ # result = ae_generate(
+ # model, (1 - mask_generate) * input, mask_generate, noise_proba
+ # )
+
######################################################################
duration = time.perf_counter() - start_time
str_duration = ""
if duration >= 60:
- str_duration += f"{int(duration//60)}min"
- duration = duration % 60
- str_duration += f"{duration:.01f}s"
- log_string(f"epoch_duration {str_duration}")
+ str_duration += f"{int(duration)//60}min"
+ str_duration += f"{int(duration)%60}s"
+ str_next = (
+ datetime.datetime.now() + datetime.timedelta(seconds=duration)
+ ).strftime("%H:%M:%S")
+ log_string(f"epoch_duration {str_duration} next_finish {str_next}")