self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
# A bit of paranoia never hurts
- assert (
- self.nb_codes <= max_nb_codes
- and self.train_input.min() >= 0
- and self.test_input.min() >= 0
- and tuple(x.item() for x in self.train_ar_mask.unique()) in { (0,), (1,), (0,1) }
- and tuple(x.item() for x in self.test_ar_mask.unique()) in { (0,), (1,), (0,1) }
- )
+ assert self.nb_codes <= max_nb_codes
+ assert self.train_input.min() >= 0
+ assert self.test_input.min() >= 0
+ assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
+ (0,),
+ (1,),
+ (0, 1),
+ }
+ assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
+ (0,),
+ (1,),
+ (0, 1),
+ }
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
device=self.device,
)
+ log_ground_truth = ar_mask.min() == 0
+
if logger is not None:
for sp, st in zip(result[:10], input[:10]):
logger(
f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}"
)
- logger(
- f" {n_epoch} ground truth {self.problem.seq2str(st)}"
- )
+ if log_ground_truth:
+ logger(
+ f" {n_epoch} ground truth {self.problem.seq2str(st)}"
+ )
- nb_total, nb_correct = self.problem.compute_nb_correct(input, ar_mask, result)
+ nb_total, nb_correct = self.problem.compute_nb_correct(
+ input, ar_mask, result
+ )
# nb_total = ar_mask.sum().item()
# nb_correct = ((result == input).long() * ar_mask).sum().item()