From 5366dfd7bd57ec3298d1030f7d5327ff26bc5aad Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 18 Jul 2023 08:44:21 +0200 Subject: [PATCH] Update. --- main.py | 1 + tasks.py | 80 ++++++++++++++++++++++++++++++++++++++------------------ world.py | 11 +++++--- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index e18887b..3be3d55 100755 --- a/main.py +++ b/main.py @@ -266,6 +266,7 @@ picoclvr_pruner_eval = ( if args.task == "sandbox": task = tasks.SandBox( + tasks.ProblemByheart(), nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, diff --git a/tasks.py b/tasks.py index 9cd06ae..eef84af 100755 --- a/tasks.py +++ b/tasks.py @@ -64,10 +64,10 @@ class Task: class Problem: - def generate(nb): + def generate_sequences(self, nb): pass - def perf(seq, logger): + def log_performance(self, sequences, logger): pass @@ -75,15 +75,33 @@ class ProblemByheart(Problem): def __init__(self): nb_seq, len_prompt, len_result = 100, 5, 5 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) - self.seq[:, len_prompt] = -1 + self.seq[:, len_prompt] = 10 def generate_sequences(self, nb): - return self.seq[torch.randint(self.seq.size(0), (nb,))] - + sequences = self.seq[torch.randint(self.seq.size(0), (nb,))] + ar_mask = (sequences==10).long() + ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) + return sequences, ar_mask + + # problems = [ProblemByheart()] + # nb_common_codes = 100 + + # def generate_sequences(nb_samples): + # problem_indexes = torch.randint(len(problems), (nb_samples,)) + # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) + # print(f"{nb_samples_per_problem}") + # all_seq = [] + # for nb, p in zip(nb_samples_per_problem, problems): + # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + # return all_seq + + # for strain, stest in zip(train_seq, test_seq): + # s = torch.cat((strain, stest), 0) class SandBox(Task): def __init__( self, + problem, nb_train_samples, nb_test_samples, batch_size, @@ -93,24 +111,10 @@ class SandBox(Task): super().__init__() self.batch_size = batch_size + self.device = device - problems = [ProblemByheart()] - nb_common_codes = 100 - - def generate_sequences(nb_samples): - problem_indexes = torch.randint(len(problems), (nb_samples,)) - nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) - print(f"{nb_samples_per_problem}") - all_seq = [] - for nb, p in zip(nb_samples_per_problem, problems): - all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) - return all_seq - - train_seq = generate_sequences(nb_train_samples) - test_seq = generate_sequences(nb_test_samples) - - for strain, stest in zip(train_seq, test_seq): - s = torch.cat((strain, stest), 0) + self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples) + self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -132,11 +136,35 @@ class SandBox(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): - # logger( - # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - # ) - pass + def compute_accuracy(input, ar_mask): + result = input.clone() * (1-ar_mask) + masked_inplace_autoregression( + model, + self.batch_size, + result, + ar_mask, + deterministic_synthesis, + progress_bar_desc=None, + device=self.device, + ) + + nb_total = ar_mask.sum().item() + nb_correct = ((result==input).long() * ar_mask).sum().item() + + return nb_total, nb_correct + + train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask) + + logger( + f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" + ) + + test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask) + + logger( + f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" + ) ###################################################################### diff --git a/world.py b/world.py index b35a08e..12c6553 100755 --- a/world.py +++ b/world.py @@ -96,8 +96,6 @@ def train_encoder( logger=None, device=torch.device("cpu"), ): - if logger is None: - logger = lambda s: print(s) mu, std = train_input.float().mean(), train_input.float().std() @@ -157,7 +155,7 @@ def train_encoder( nb_parameters = sum(p.numel() for p in model.parameters()) - logger(f"nb_parameters {nb_parameters}") + logger(f"vqae nb_parameters {nb_parameters}") model.to(device) @@ -209,7 +207,7 @@ def train_encoder( train_loss = acc_train_loss / train_input.size(0) test_loss = acc_test_loss / test_input.size(0) - logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") + logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") sys.stdout.flush() return encoder, quantizer, decoder @@ -378,6 +376,9 @@ def create_data_and_processors( if mode == "first_last": steps = [True] + [False] * (nb_steps + 1) + [True] + if logger is None: + logger = lambda s: print(s) + train_input, train_actions = generate_episodes(nb_train_samples, steps) train_input, train_actions = train_input.to(device_storage), train_actions.to( device_storage @@ -405,6 +406,8 @@ def create_data_and_processors( pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :] z_h, z_w = z.size(2), z.size(3) + logger(f"vqae input {train_input[0].size()} output {z[0].size()}") + def frame2seq(input, batch_size=25): seq = [] p = pow2.to(device) -- 2.39.5