X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=fb85576a55168bc48ade65d90b4bddf6561bf33b;hb=a2ffcd9b27aa0f3cc0b56090a32e88b73dfa0a54;hp=eef84af68d4393b2a0b9adb1efb00743fb0318ee;hpb=5366dfd7bd57ec3298d1030f7d5327ff26bc5aad;p=picoclvr.git diff --git a/tasks.py b/tasks.py index eef84af..fb85576 100755 --- a/tasks.py +++ b/tasks.py @@ -79,7 +79,7 @@ class ProblemByheart(Problem): def generate_sequences(self, nb): sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] - ar_mask = (sequences==10).long() + ar_mask = (sequences == 10).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask @@ -87,16 +87,17 @@ class ProblemByheart(Problem): # nb_common_codes = 100 # def generate_sequences(nb_samples): - # problem_indexes = torch.randint(len(problems), (nb_samples,)) - # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - # print(f"{nb_samples_per_problem}") - # all_seq = [] - # for nb, p in zip(nb_samples_per_problem, problems): - # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - # return all_seq + # problem_indexes = torch.randint(len(problems), (nb_samples,)) + # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + # print(f"{nb_samples_per_problem}") + # all_seq = [] + # for nb, p in zip(nb_samples_per_problem, problems): + # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + # return all_seq # for strain, stest in zip(train_seq, test_seq): - # s = torch.cat((strain, stest), 0) + # s = torch.cat((strain, stest), 0) + class SandBox(Task): def __init__( @@ -107,17 +108,29 @@ class SandBox(Task): batch_size, logger=None, device=torch.device("cpu"), + max_nb_codes=1024, ): super().__init__() self.batch_size = batch_size self.device = device - self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples) + self.train_input, self.train_ar_mask = problem.generate_sequences( + nb_train_samples + ) self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + # A bit of paranoia never hurts + assert ( + self.nb_codes <= max_nb_codes + and self.train_input.min() >= 0 + and self.test_input.min() >= 0 + and tuple(self.train_ar_mask.unique()) == (0, 1) + and tuple(self.test_ar_mask.unique()) == (0, 1) + ) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -136,9 +149,8 @@ class SandBox(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - def compute_accuracy(input, ar_mask): - result = input.clone() * (1-ar_mask) + result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( model, self.batch_size, @@ -150,22 +162,27 @@ class SandBox(Task): ) nb_total = ar_mask.sum().item() - nb_correct = ((result==input).long() * ar_mask).sum().item() + nb_correct = ((result == input).long() * ar_mask).sum().item() return nb_total, nb_correct - train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask) + train_nb_total, train_nb_correct = compute_accuracy( + self.train_input, self.train_ar_mask + ) logger( f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask) + test_nb_total, test_nb_correct = compute_accuracy( + self.test_input, self.test_ar_mask + ) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + ###################################################################### import picoclvr