X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=8b57cb2056462d8265bb38c1dffb49c4d59dc41a;hb=e3a8032a070175ece08fc79c77312d5f2f59150e;hp=f8fb9b93ace534d6a225558f82b7d2d61211031a;hpb=2ac9d1299a84f96228f49fbdac02d5a7017445e5;p=picoclvr.git diff --git a/tasks.py b/tasks.py index f8fb9b9..8b57cb2 100755 --- a/tasks.py +++ b/tasks.py @@ -60,6 +60,69 @@ class Task: pass +###################################################################### + + +class Problem: + def generate(nb): + pass + + def perf(seq, logger): + pass + + +class ProblemByheart(Problem): + def __init__(self): + pass + + +class SandBox(Task): + def __init__( + self, + nb_train_samples, + nb_test_samples, + batch_size, + logger=None, + device=torch.device("cpu"), + ): + super().__init__() + + self.batch_size = batch_size + + def generate_sequences(nb_samples): + problem_indexes = torch.randint(len(problems), (nb_samples,)) + nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + print(f"{nb_samples_per_problem}") + + self.train_input = generate_sequences(nb_train_samples) + self.test_input = generate_sequences(nb_test_samples) + + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + + def batches(self, split="train", nb_to_use=-1, desc=None): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + if nb_to_use > 0: + input = input[:nb_to_use] + if desc is None: + desc = f"epoch-{split}" + for batch in tqdm.tqdm( + input.split(self.batch_size), dynamic_ncols=True, desc=desc + ): + yield batch + + def vocabulary_size(self): + return self.nb_codes + + def produce_results( + self, n_epoch, model, result_dir, logger, deterministic_synthesis + ): + # logger( + # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + # ) + pass + + ###################################################################### import picoclvr @@ -108,6 +171,8 @@ class PicoCLVR(Task): pruner_train=None, pruner_eval=None, ): + super().__init__() + def generate_descr(nb, cache_suffix, pruner): return picoclvr.generate( nb, @@ -296,6 +361,8 @@ class MNIST(Task): def __init__( self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu") ): + super().__init__() + self.nb_train_samples = (nb_train_samples,) self.nb_test_samples = (nb_test_samples,) self.batch_size = batch_size @@ -366,6 +433,8 @@ class Maze(Task): nb_walls, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.height = height self.width = width @@ -537,6 +606,8 @@ class Snake(Task): prompt_length, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.height = height self.width = width @@ -635,6 +706,8 @@ class Stack(Task): fraction_values_for_train=None, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.nb_steps = nb_steps self.nb_stacks = nb_stacks @@ -782,6 +855,8 @@ class Expr(Task): batch_size, device=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.device = device @@ -961,6 +1036,8 @@ class World(Task): device=torch.device("cpu"), device_storage=torch.device("cpu"), ): + super().__init__() + self.batch_size = batch_size self.device = device