From a2ffcd9b27aa0f3cc0b56090a32e88b73dfa0a54 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 08:49:15 +0200 Subject: [PATCH] Update. --- tasks.py | 47 ++++++++++++++++++++++++++++++++--------------- world.py | 7 ++++--- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/tasks.py b/tasks.py index eef84af..fb85576 100755 --- a/tasks.py +++ b/tasks.py @@ -79,7 +79,7 @@ class ProblemByheart(Problem): def generate_sequences(self, nb): sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] - ar_mask = (sequences==10).long() + ar_mask = (sequences == 10).long() ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask @@ -87,16 +87,17 @@ class ProblemByheart(Problem): # nb_common_codes = 100 # def generate_sequences(nb_samples): - # problem_indexes = torch.randint(len(problems), (nb_samples,)) - # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - # print(f"{nb_samples_per_problem}") - # all_seq = [] - # for nb, p in zip(nb_samples_per_problem, problems): - # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - # return all_seq + # problem_indexes = torch.randint(len(problems), (nb_samples,)) + # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + # print(f"{nb_samples_per_problem}") + # all_seq = [] + # for nb, p in zip(nb_samples_per_problem, problems): + # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + # return all_seq # for strain, stest in zip(train_seq, test_seq): - # s = torch.cat((strain, stest), 0) + # s = torch.cat((strain, stest), 0) + class SandBox(Task): def __init__( @@ -107,17 +108,29 @@ class SandBox(Task): batch_size, logger=None, device=torch.device("cpu"), + max_nb_codes=1024, ): super().__init__() self.batch_size = batch_size self.device = device - self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples) + self.train_input, self.train_ar_mask = problem.generate_sequences( + nb_train_samples + ) self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 + # A bit of paranoia never hurts + assert ( + self.nb_codes <= max_nb_codes + and self.train_input.min() >= 0 + and self.test_input.min() >= 0 + and tuple(self.train_ar_mask.unique()) == (0, 1) + and tuple(self.test_ar_mask.unique()) == (0, 1) + ) + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -136,9 +149,8 @@ class SandBox(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - def compute_accuracy(input, ar_mask): - result = input.clone() * (1-ar_mask) + result = input.clone() * (1 - ar_mask) masked_inplace_autoregression( model, self.batch_size, @@ -150,22 +162,27 @@ class SandBox(Task): ) nb_total = ar_mask.sum().item() - nb_correct = ((result==input).long() * ar_mask).sum().item() + nb_correct = ((result == input).long() * ar_mask).sum().item() return nb_total, nb_correct - train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask) + train_nb_total, train_nb_correct = compute_accuracy( + self.train_input, self.train_ar_mask + ) logger( f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask) + test_nb_total, test_nb_correct = compute_accuracy( + self.test_input, self.test_ar_mask + ) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" ) + ###################################################################### import picoclvr diff --git a/world.py b/world.py index 12c6553..1d64fa3 100755 --- a/world.py +++ b/world.py @@ -61,12 +61,13 @@ class SignSTE(nn.Module): else: return s + class DiscreteSampler2d(nn.Module): def __init__(self): super().__init__() def forward(self, x): - s = (x >= x.max(-3,keepdim=True).values).float() + s = (x >= x.max(-3, keepdim=True).values).float() if self.training: u = x.softmax(dim=-3) @@ -96,7 +97,6 @@ def train_encoder( logger=None, device=torch.device("cpu"), ): - mu, std = train_input.float().mean(), train_input.float().std() def encoder_core(depth, dim): @@ -459,7 +459,8 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - 25000, 1000, + 25000, + 1000, nb_epochs=5, mode="first_last", nb_steps=20, -- 2.20.1