- test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
- test_action_seq += nb_frame_codes
- self.test_input = torch.cat(
- (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"test_after {e}")
+
+ logger(f"----------------------------------------------------------")
+
+ nb_total = ar_mask.sum().item()
+ nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+ logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
+######################################################################
+
+import qmlp
+
+
+class QMLP(Task):
+ ######################
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ result_dir,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.device = device
+ self.batch_size = batch_size
+ self.nb_samples_per_mlp = 256
+
+ if logger is not None:
+ logger(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+
+ seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
+ nb_mlps=nb_train_samples + nb_test_samples,
+ nb_samples=self.nb_samples_per_mlp,
+ device=self.device,
+ batch_size=64,
+ nb_epochs=250,
+ nb_mlps_per_batch=1024,