diff --git a/tasks.py b/tasks.py
index 5583fc8..fb85576 100755
--- a/tasks.py
+++ b/tasks.py
@@ -64,10 +64,10 @@ class Task:
 
 
 class Problem:
-    def generate(nb):
+    def generate_sequences(self, nb):
         pass
 
-    def perf(seq, logger):
+    def log_performance(self, sequences, logger):
         pass
 
 
@@ -75,44 +75,62 @@ class ProblemByheart(Problem):
     def __init__(self):
         nb_seq, len_prompt, len_result = 100, 5, 5
         self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
-        self.seq[:,len_prompt]=-1
+        self.seq[:, len_prompt] = 10
 
     def generate_sequences(self, nb):
-        return self.seq[torch.randint(self.seq.size(0), (nb,))]
+        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+        ar_mask = (sequences == 10).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    # problems = [ProblemByheart()]
+    # nb_common_codes = 100
+
+    # def generate_sequences(nb_samples):
+    #     problem_indexes = torch.randint(len(problems), (nb_samples,))
+    #     nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+    #     print(f"{nb_samples_per_problem}")
+    #     all_seq = []
+    #     for nb, p in zip(nb_samples_per_problem, problems):
+    #         all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+    #     return all_seq
+
+    # for strain, stest in zip(train_seq, test_seq):
+    #     s = torch.cat((strain, stest), 0)
+
 
 
 class SandBox(Task):
     def __init__(
         self,
+        problem,
         nb_train_samples,
         nb_test_samples,
         batch_size,
         logger=None,
         device=torch.device("cpu"),
+        max_nb_codes=1024,
     ):
         super().__init__()
 
         self.batch_size = batch_size
+        self.device = device
 
-        problems = [ ProblemByheart() ]
-        nb_common_codes = 100
-
-        def generate_sequences(nb_samples):
-            problem_indexes = torch.randint(len(problems), (nb_samples,))
-            nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
-            print(f"{nb_samples_per_problem}")
-            all_seq = []
-            for nb, p in zip(nb_samples_per_problem,problems):
-                all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
-            return all_seq
-
-        train_seq = generate_sequences(nb_train_samples)
-        test_seq = generate_sequences(nb_test_samples)
-
-        for strain, stest in zip(train_seq, test_seq):
-            s = torch.cat((strain,stest),0)
+        self.train_input, self.train_ar_mask = problem.generate_sequences(
+            nb_train_samples
+        )
+        self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
+        # A bit of paranoia never hurts
+        assert (
+            self.nb_codes <= max_nb_codes
+            and self.train_input.min() >= 0
+            and self.test_input.min() >= 0
+            and tuple(self.train_ar_mask.unique()) == (0, 1)
+            and tuple(self.test_ar_mask.unique()) == (0, 1)
+        )
+
     def batches(self, split="train", nb_to_use=-1, desc=None):
         assert split in {"train", "test"}
         input = self.train_input if split == "train" else self.test_input
@@ -131,10 +149,38 @@ class SandBox(Task):
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
-        # logger(
-        #     f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        # )
-        pass
+        def compute_accuracy(input, ar_mask):
+            result = input.clone() * (1 - ar_mask)
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            nb_total = ar_mask.sum().item()
+            nb_correct = ((result == input).long() * ar_mask).sum().item()
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(
+            self.train_input, self.train_ar_mask
+        )
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(
+            self.test_input, self.test_ar_mask
+        )
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
 
 
 ######################################################################
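
Note (not part of the patch): a minimal standalone sketch of what the new
ar_mask construction in ProblemByheart.generate_sequences computes, assuming
only torch; the shapes mirror the class defaults (len_prompt=5, len_result=5,
separator token 10) and the example values are illustrative.

    import torch

    # Toy batch of prompt/separator/result sequences, as in ProblemByheart:
    # tokens are 0..9, and the separator token 10 sits at position len_prompt.
    sequences = torch.randint(10, (3, 5 + 1 + 5))
    sequences[:, 5] = 10

    # 1 at each separator position, 0 elsewhere.
    ar_mask = (sequences == 10).long()
    # cumsum(1) is 1 from the separator onward; subtracting ar_mask shifts
    # that by one step, so the mask is 1 strictly *after* the separator,
    # i.e. exactly on the result tokens to be generated. clamp(max=1)
    # keeps the mask binary if a sequence contained several separators.
    ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)

    print(ar_mask[0])  # tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    # produce_results() then blanks exactly those positions before the
    # autoregressive in-filling, so accuracy is measured on them only:
    result = sequences.clone() * (1 - ar_mask)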