X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=7a4abbeea23d9494d835cf6a040a36ab2eb53cc2;hb=0d86d8ca945722438d3c85cd01b3740269ed3546;hp=0858282fa3b5cab185d14025d5ad758de44411c2;hpb=16e7952b7cc32ca21498fa3a12fb79f679ea8c21;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 0858282..7a4abbe 100755 --- a/tasks.py +++ b/tasks.py @@ -110,15 +110,20 @@ class SandBox(Task): self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - # A bit of paranoia never hurts - assert ( - self.nb_codes <= max_nb_codes - and self.train_input.min() >= 0 - and self.test_input.min() >= 0 - and tuple(x.item() for x in self.train_ar_mask.unique()) in { (0,), (1,), (0,1) } - and tuple(x.item() for x in self.test_ar_mask.unique()) in { (0,), (1,), (0,1) } - ) + assert self.nb_codes <= max_nb_codes + assert self.train_input.min() >= 0 + assert self.test_input.min() >= 0 + assert tuple(x.item() for x in self.train_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } + assert tuple(x.item() for x in self.test_ar_mask.unique()) in { + (0,), + (1,), + (0, 1), + } def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -152,16 +157,21 @@ class SandBox(Task): device=self.device, ) + log_ground_truth = ar_mask.min() == 0 + if logger is not None: for sp, st in zip(result[:10], input[:10]): logger( f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}" ) - logger( - f" {n_epoch} ground truth {self.problem.seq2str(st)}" - ) + if log_ground_truth: + logger( + f" {n_epoch} ground truth {self.problem.seq2str(st)}" + ) - nb_total, nb_correct = self.problem.compute_nb_correct(input, ar_mask, result) + nb_total, nb_correct = self.problem.compute_nb_correct( + input, ar_mask, result + ) # nb_total = ar_mask.sum().item() # nb_correct = ((result == input).long() * ar_mask).sum().item()