X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=eef84af68d4393b2a0b9adb1efb00743fb0318ee;hb=5366dfd7bd57ec3298d1030f7d5327ff26bc5aad;hp=9cd06ae054ae7e1adee634a9361adb8680d1356c;hpb=0f580d4facb4b4b485d0a38d62d06c0639715b77;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 9cd06ae..eef84af 100755 --- a/tasks.py +++ b/tasks.py @@ -64,10 +64,10 @@ class Task: class Problem: - def generate(nb): + def generate_sequences(self, nb): pass - def perf(seq, logger): + def log_performance(self, sequences, logger): pass @@ -75,15 +75,33 @@ class ProblemByheart(Problem): def __init__(self): nb_seq, len_prompt, len_result = 100, 5, 5 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) - self.seq[:, len_prompt] = -1 + self.seq[:, len_prompt] = 10 def generate_sequences(self, nb): - return self.seq[torch.randint(self.seq.size(0), (nb,))] - + sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] + ar_mask = (sequences==10).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + # problems = [ProblemByheart()] + # nb_common_codes = 100 + + # def generate_sequences(nb_samples): + # problem_indexes = torch.randint(len(problems), (nb_samples,)) + # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + # print(f"{nb_samples_per_problem}") + # all_seq = [] + # for nb, p in zip(nb_samples_per_problem, problems): + # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + # return all_seq + + # for strain, stest in zip(train_seq, test_seq): + # s = torch.cat((strain, stest), 0) class SandBox(Task): def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -93,24 +111,10 @@ class SandBox(Task): super().__init__() self.batch_size = batch_size + self.device = device - problems = [ProblemByheart()] - nb_common_codes = 100 - - def generate_sequences(nb_samples): - problem_indexes = torch.randint(len(problems), (nb_samples,)) - nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - print(f"{nb_samples_per_problem}") - all_seq = [] - for nb, p in zip(nb_samples_per_problem, problems): - all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - return all_seq - - train_seq = generate_sequences(nb_train_samples) - test_seq = generate_sequences(nb_test_samples) - - for strain, stest in zip(train_seq, test_seq): - s = torch.cat((strain, stest), 0) + self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples) + self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -132,11 +136,35 @@ class SandBox(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - # logger( - # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - # ) - pass + def compute_accuracy(input, ar_mask): + result = input.clone() * (1-ar_mask) + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + nb_total = ar_mask.sum().item() + nb_correct = ((result==input).long() * ar_mask).sum().item() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask) + + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) ######################################################################