X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=eef84af68d4393b2a0b9adb1efb00743fb0318ee;hb=5366dfd7bd57ec3298d1030f7d5327ff26bc5aad;hp=8b57cb2056462d8265bb38c1dffb49c4d59dc41a;hpb=a92a5ca00f4277f7a133fa6cfaada2bc1981f524;p=picoclvr.git

diff --git a/tasks.py b/tasks.py
index 8b57cb2..eef84af 100755
--- a/tasks.py
+++ b/tasks.py
@@ -64,21 +64,44 @@ class Task:
 
 
 class Problem:
-    def generate(nb):
+    def generate_sequences(self, nb):
         pass
 
-    def perf(seq, logger):
+    def log_performance(self, sequences, logger):
         pass
 
 
 class ProblemByheart(Problem):
     def __init__(self):
-        pass
-
+        nb_seq, len_prompt, len_result = 100, 5, 5
+        self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
+        self.seq[:, len_prompt] = 10
+
+    def generate_sequences(self, nb):
+        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+        ar_mask = (sequences==10).long()
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+        # problems = [ProblemByheart()]
+        # nb_common_codes = 100
+
+        # def generate_sequences(nb_samples):
+        #    problem_indexes = torch.randint(len(problems), (nb_samples,))
+        #    nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+        #    print(f"{nb_samples_per_problem}")
+        #    all_seq = []
+        #    for nb, p in zip(nb_samples_per_problem, problems):
+        #        all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+        #    return all_seq
+
+        # for strain, stest in zip(train_seq, test_seq):
+        #    s = torch.cat((strain, stest), 0)
 
 class SandBox(Task):
     def __init__(
         self,
+        problem,
         nb_train_samples,
         nb_test_samples,
         batch_size,
@@ -88,14 +111,10 @@ class SandBox(Task):
         super().__init__()
 
         self.batch_size = batch_size
+        self.device = device
 
-        def generate_sequences(nb_samples):
-            problem_indexes = torch.randint(len(problems), (nb_samples,))
-            nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
-            print(f"{nb_samples_per_problem}")
-
-        self.train_input = generate_sequences(nb_train_samples)
-        self.test_input = generate_sequences(nb_test_samples)
+        self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples)
+        self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
 
         self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
 
@@ -117,11 +136,35 @@ class SandBox(Task):
 
     def produce_results(
         self, n_epoch, model, result_dir, logger, deterministic_synthesis
     ):
-        # logger(
-        #     f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        # )
-        pass
+        def compute_accuracy(input, ar_mask):
+            result = input.clone() * (1-ar_mask)
+            masked_inplace_autoregression(
+                model,
+                self.batch_size,
+                result,
+                ar_mask,
+                deterministic_synthesis,
+                progress_bar_desc=None,
+                device=self.device,
+            )
+
+            nb_total = ar_mask.sum().item()
+            nb_correct = ((result==input).long() * ar_mask).sum().item()
+
+            return nb_total, nb_correct
+
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask)
+
+        logger(
+            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+        )
+
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask)
+
+        logger(
+            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+        )
 
 
 ######################################################################
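
A note on the patch above: in the new `ProblemByheart.generate_sequences`, `ar_mask` flags exactly the positions strictly after the separator token (value 10), so only the result part of each sequence is regenerated and scored, while the prompt and the separator stay at 0. A minimal standalone sketch of that idiom with made-up toy sequences (not part of the commit):

```python
import torch

# Toy sequences: three prompt tokens, the separator token 10, two result tokens.
seq = torch.tensor([[3, 7, 1, 10, 4, 2],
                    [5, 0, 9, 10, 8, 8]])

m = (seq == 10).long()                    # 1 exactly at the separator
ar_mask = (m.cumsum(1) - m).clamp(max=1)  # 1 strictly after it, 0 elsewhere

print(ar_mask)
# tensor([[0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1]])
```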