From 26ef53ee3769c3b6b92b85d15b5a43cbd18ede07 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 13 Oct 2023 14:26:31 +0200 Subject: [PATCH] Update. --- qmlp.py | 18 +++++++++--------- tasks.py | 49 +++++++++++++++++++------------------------------ 2 files changed, 28 insertions(+), 39 deletions(-) diff --git a/qmlp.py b/qmlp.py index e12f0e1..a7defe4 100755 --- a/qmlp.py +++ b/qmlp.py @@ -224,7 +224,7 @@ def generate_sequence_and_test_set( nb_mlps_per_batch=1024, ): - inputs, q_test_sets = [],[] + seqs, q_test_sets = [],[] for n in range(0,nb_mlps,nb_mlps_per_batch): q_train_set, q_test_set, q_params = generate_sets_and_params( @@ -235,7 +235,7 @@ def generate_sequence_and_test_set( device=device, ) - inputs.append(torch.cat( + seqs.append(torch.cat( [ q_train_set, q_train_set.new_full( @@ -252,10 +252,10 @@ def generate_sequence_and_test_set( q_test_sets.append(q_test_set) - input = torch.cat(inputs) + seq = torch.cat(seqs) q_test_set = torch.cat(q_test_sets) - return input, q_test_set + return seq, q_test_set ###################################################################### @@ -271,7 +271,7 @@ if __name__ == "__main__": data = [] - input, q_test_set = generate_sequence_and_test_set( + seq, q_test_set = generate_sequence_and_test_set( nb_mlps=batch_nb_mlps, nb_samples=nb_samples, device=device, @@ -281,11 +281,11 @@ if __name__ == "__main__": ) end_time = time.perf_counter() - print(f"{input.size(0) / (end_time - start_time):.02f} samples per second") + print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second") - q_train_set = input[:, : nb_samples * 3] - q_params = input[:, nb_samples * 3 + 1 :] - print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}") + q_train_set = seq[:, : nb_samples * 3] + q_params = seq[:, nb_samples * 3 + 1 :] + print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}") error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17) print(f"train {error_train*100}%") error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17) diff --git a/tasks.py b/tasks.py index ea10d7c..066f1bb 100755 --- a/tasks.py +++ b/tasks.py @@ -1570,39 +1570,28 @@ class QMLP(Task): self.device = device self.batch_size = batch_size + self.nb_samples_per_mlp = 256 if logger is not None: logger( f"generating {nb_train_samples+nb_test_samples} samples (can take some time)" ) - self.train_descr = self.grid_factory.generate_samples( - nb_train_samples, lambda r: tqdm.tqdm(r) - ) - self.test_descr = self.grid_factory.generate_samples( - nb_test_samples, lambda r: tqdm.tqdm(r) + seq, q_test_set = generate_sequence_and_test_set( + nb_mlps=nb_train_samples+nb_test_samples, + nb_samples=self.nb_samples_per_mlp, + device=self.device, + batch_size=64, + nb_epochs=250, + nb_mlps_per_batch=1024 ) - # Build the tokenizer - tokens = set() - for d in [self.train_descr, self.test_descr]: - for s in d: - for t in s.strip().split(" "): - tokens.add(t) - # make this set a sorted list to get the same tensors given - # the same descr - tokens = list(tokens) - tokens.sort() - tokens = ["#"] + tokens - self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) - self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) - self.t_nul = self.token2id["#"] - self.t_true = self.token2id["true"] - self.t_false = self.token2id["false"] + self.train_input = seq[:nb_train_samples] + self.train_q_test_set = q_test_set[:nb_train_samples] + self.test_input = seq[nb_train_samples:] + self.test_q_test_set = q_test_set[nb_train_samples:] - # Tokenize the train and test sets - self.train_input = self.str2tensor(self.train_descr) - self.test_input = self.str2tensor(self.test_descr) + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train"): assert split in {"train", "test"} @@ -1613,14 +1602,14 @@ class QMLP(Task): yield self.trim(batch) def vocabulary_size(self): - return len(self.token2id) + return self.nb_codes def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis ): correct = self.test_input[:1000] result = correct.clone() - ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long() + ar_mask = torch.arange(result.size(1)) > self.nb_samples_per_mlp * 3 + 1 result *= 1 - ar_mask # paraaaaanoiaaaaaaa logger(f"----------------------------------------------------------") @@ -1644,11 +1633,11 @@ class QMLP(Task): logger(f"----------------------------------------------------------") - nb_total = ar_mask.sum().item() - nb_correct = ((correct == result).long() * ar_mask).sum().item() + q_train_set = result[:, : nb_samples * 3] + q_params = result[:, nb_samples * 3 + 1 :] + error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17) - logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}") - logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}") + logger(f"{error_test=}") ###################################################################### -- 2.39.5