From: François Fleuret Date: Sat, 8 Jul 2023 08:57:49 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=2175a87ad4304a97c63ac9cca6224d0c0b74c64e;p=culture.git Update. --- diff --git a/README.txt b/README.txt index 231d292..d4cb93d 100644 --- a/README.txt +++ b/README.txt @@ -14,4 +14,4 @@ For the arithmetic expressions experiments # 352M parameters / 2.5M samples, reaches 99.80% after 12 epochs, the learning rate schedule is obviously terrible -./main.py --task=expr --nb_blocks=48 --result_dir=results_expr_48b_d1024_2.5M --dim_model=1024 --nb_train_samples=2500000 +./main.py --task=expr --nb_blocks=48 --dim_model=1024 --nb_train_samples=2500000 --result_dir=results_expr_48b_d1024_2.5M diff --git a/main.py b/main.py index 003028a..56b7e1c 100755 --- a/main.py +++ b/main.py @@ -127,6 +127,8 @@ parser.add_argument("--expr_nb_variables", type=int, default=5) parser.add_argument("--expr_sequence_length", type=int, default=30) +parser.add_argument("--expr_input_file", type=str, default=None) + ###################################################################### args = parser.parse_args() @@ -366,6 +368,20 @@ else: ###################################################################### +if args.task == "expr" and args.expr_input_file is not None: + task.produce_results( + nb_epochs_finished, + model, + args.result_dir, + log_string, + args.deterministic_synthesis, + args.expr_input_file, + ) + + exit(0) + +###################################################################### + nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default # Compute the entropy of the training tokens diff --git a/tasks.py b/tasks.py index 62a8891..912b405 100755 --- a/tasks.py +++ b/tasks.py @@ -776,6 +776,20 @@ import expr class Expr(Task): + def tensorize(self, sequences): + len_max = max([len(x) for x in sequences]) + return torch.cat( + [ + torch.tensor( + [ + [self.char2id[c] for c in s + "#" * (len_max - len(s))] + for s in sequences + ] + ) + ], + 0, + ).to(self.device) + def __init__( self, nb_train_samples, @@ -800,43 +814,17 @@ class Expr(Task): nb_variables=nb_variables, length=sequence_length, ) - self.char2id = dict( - [ - (c, n) - for n, c in enumerate( - set("#" + "".join(train_sequences + test_sequences)) - ) - ] - ) + + symbols = list(set("#" + "".join(train_sequences + test_sequences))) + symbols.sort() + + self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) self.id2char = dict([(n, c) for c, n in self.char2id.items()]) self.filler, self.space = self.char2id["#"], self.char2id[" "] - len_max = max([len(x) for x in train_sequences]) - self.train_input = torch.cat( - [ - torch.tensor( - [ - [self.char2id[c] for c in s + "#" * (len_max - len(s))] - for s in train_sequences - ] - ) - ], - 0, - ).to(device) - - len_max = max([len(x) for x in test_sequences]) - self.test_input = torch.cat( - [ - torch.tensor( - [ - [self.char2id[c] for c in s + "#" * (len_max - len(s))] - for s in test_sequences - ] - ) - ], - 0, - ).to(device) + self.train_input = self.tensorize(train_sequences) + self.test_input = self.tensorize(test_sequences) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 @@ -862,7 +850,13 @@ class Expr(Task): return "".join([self.id2char[k.item()] for k in s]) def produce_results( - self, n_epoch, model, result_dir, logger, deterministic_synthesis + self, + n_epoch, + model, + result_dir, + logger, + deterministic_synthesis, + input_file=None, ): with torch.autograd.no_grad(): t = model.training @@ -931,7 +925,14 @@ class Expr(Task): ############################################################## # Log a few generated sequences - input = self.test_input[:10] + if input_file is None: + input = self.test_input[:10] + else: + with open(input_file, "r") as f: + sequences = [e.strip() for e in f.readlines()] + sequences = [s + " " + "#" * 50 for s in sequences] + input = self.tensorize(sequences) + result = input.clone() ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) result = (1 - ar_mask) * result + ar_mask * self.filler