X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=f7d03cfcfab0ef10ab5066729d37b1a09563ff6d;hb=HEAD;hp=a83107b4bda4b2b7f5ba57d2a1db1dcbdd5ff734;hpb=a1a7cb9e680378db521f2a1e2139db0e2db903de;p=mygpt.git diff --git a/main.py b/main.py index a83107b..f7d03cf 100755 --- a/main.py +++ b/main.py @@ -15,249 +15,325 @@ import mygpt ###################################################################### -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### +parser = argparse.ArgumentParser(description="My own GPT.") -parser = argparse.ArgumentParser(description = 'My own GPT.') +parser.add_argument("--log_filename", type=str, default="train.log") -parser.add_argument('--log_filename', - type = str, default = 'train.log') +parser.add_argument("--seed", type=int, default=0) -parser.add_argument('--download', - action='store_true', default = False) +parser.add_argument("--nb_epochs", type=int, default=None) -parser.add_argument('--seed', - type = int, default = 0) +parser.add_argument("--batch_size", type=int, default=25) -parser.add_argument('--nb_epochs', - type = int, default = -1) +parser.add_argument("--data", type=str, default="wiki103") -parser.add_argument('--batch_size', - type = int, default = 25) +parser.add_argument("--data_size", type=int, default=None) -parser.add_argument('--data', - type = str, default = 'wiki103') +parser.add_argument("--optim", type=str, default="adam") -parser.add_argument('--data_size', - type = int, default = -1) +parser.add_argument("--learning_rate", type=float, default=1e-3) -parser.add_argument('--optim', - type = str, default = 'adam') +parser.add_argument("--learning_rate_end", type=float, default=1e-6) -parser.add_argument('--learning_rate', - type = float, default = 1e-4) +parser.add_argument("--dim_model", type=int, default=None) -parser.add_argument('--dim_model', - type = int, default = 512) +parser.add_argument("--dim_keys", type=int, default=None) -parser.add_argument('--dim_keys', - type = int, default = 64) +parser.add_argument("--dim_hidden", type=int, default=None) -parser.add_argument('--dim_hidden', - type = int, default = 2048) +parser.add_argument("--nb_heads", type=int, default=None) -parser.add_argument('--nb_heads', - type = int, default = 8) +parser.add_argument("--nb_blocks", type=int, default=None) -parser.add_argument('--nb_blocks', - type = int, default = 12) +parser.add_argument("--dropout", type=float, default=0.1) -parser.add_argument('--dropout', - type = float, default = 0.1) +parser.add_argument("--deterministic_synthesis", action="store_true", default=False) -parser.add_argument('--synthesis_sampling', - action='store_true', default = True) +parser.add_argument("--no_checkpoint", action="store_true", default=False) -parser.add_argument('--no_checkpoint', - action='store_true', default = False) - -parser.add_argument('--checkpoint_name', - type = str, default = 'checkpoint.pth') +parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth") ############################## # picoclvr options -parser.add_argument('--picoclvr_many_colors', - action='store_true', default = False) +parser.add_argument("--picoclvr_nb_colors", type=int, default=5) -parser.add_argument('--picoclvr_height', - type = int, default = 12) +parser.add_argument("--picoclvr_height", type=int, default=12) -parser.add_argument('--picoclvr_width', - type = int, default = 16) +parser.add_argument("--picoclvr_width", type=int, default=16) ###################################################################### args = parser.parse_args() -log_file = open(args.log_filename, 'w') +log_file = open(args.log_filename, "w") if args.seed >= 0: torch.manual_seed(args.seed) ###################################################################### + def log_string(s): - t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime()) + t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) if log_file is not None: - log_file.write(t + s + '\n') + log_file.write(t + s + "\n") log_file.flush() print(t + s) sys.stdout.flush() + for n in vars(args): - log_string(f'args.{n} {getattr(args, n)}') + log_string(f"args.{n} {getattr(args, n)}") ###################################################################### -def produce_results( - self, - model, nb_samples, nb_tokens_to_generate, starting_input = None, - device = 'cpu' +default_args = { + "mnist": { + "nb_epochs": 10, + "dim_model": 64, + "dim_keys": 64, + "dim_hidden": 128, + "nb_heads": 4, + "nb_blocks": 6, + }, + "mnist-debug": { + "nb_epochs": 2, + "data_size": 10000, + "dim_model": 8, + "dim_keys": 8, + "dim_hidden": 8, + "nb_heads": 2, + "nb_blocks": 4, + }, + "wiki103": { + "nb_epochs": 25, + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 12, + }, + "picoclvr": { + "nb_epochs": 25, + "dim_model": 512, + "dim_keys": 64, + "dim_hidden": 2048, + "nb_heads": 8, + "nb_blocks": 12, + }, +} + +if args.data in default_args: + for k, v in default_args[args.data].items(): + if getattr(args, k) is None: + setattr(args, k, v) + +###################################################################### + + +def autoregression( + model, + batch_size, + nb_samples, + nb_tokens_to_generate, + primer=None, + device=torch.device("cpu"), ): - results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): + results = torch.zeros( + nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device + ) + + if primer is None: + first = 0 + else: + first = primer.size(1) + results = torch.cat((primer, results), 1) + + for input in results.split(batch_size): + for s in range(first, input.size(1)): output = model(input) logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + if args.deterministic_synthesis: + t_next = logits.argmax(1) else: - t = logits.argmax(1) - input[:, s + 1] = t + dist = torch.distributions.categorical.Categorical(logits=logits) + t_next = dist.sample() + input[:, s] = t_next + + return results + ###################################################################### + class Task: - def batches(self, split = 'train'): + def batches(self, split="train"): pass def vocabulary_size(self): pass - def produce_results(self, n_epoch, model, nb_tokens = 50): + def produce_results(self, n_epoch, model): pass + ###################################################################### import picoclvr -class TaskPicoCLVR(Task): - def __init__(self, batch_size, - height, width, many_colors = False, - device = torch.device('cpu')): +class TaskPicoCLVR(Task): + # Make a tensor from a list of strings + def tensorize(self, descr): + token_descr = [s.strip().split(" ") for s in descr] + l = max([len(s) for s in token_descr]) + padded_token_descr = [s + [""] * (l - len(s)) for s in token_descr] + id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr] + return torch.tensor(id_descr, device=self.device) + + def trim(self, x, token=""): + n = self.token2id[token] + i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0) + a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min() + return x[:, a:b] + + def __init__( + self, batch_size, height, width, nb_colors=5, device=torch.device("cpu") + ): def generate_descr(nb): - descr = picoclvr.generate( - nb, - height = self.height, width = self.width, - many_colors = many_colors + return picoclvr.generate( + nb, height=self.height, width=self.width, nb_colors=nb_colors ) - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] - - return descr - self.height = height self.width = width self.batch_size = batch_size self.device = device - nb = args.data_size if args.data_size > 0 else 250000 + nb = args.data_size if args.data_size is not None else 250000 + log_string(f"generating {nb} samples (can take some time)") self.train_descr = generate_descr((nb * 4) // 5) self.test_descr = generate_descr((nb * 1) // 5) # Build the tokenizer - tokens = set() - for d in [ self.train_descr, self.test_descr ]: + tokens = {""} + for d in [self.train_descr, self.test_descr]: for s in d: - for t in s: tokens.add(t) - self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) - self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - - t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ] - self.train_input = torch.tensor(t, device = self.device) - t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ] - self.test_input = torch.tensor(t, device = self.device) - - def batches(self, split = 'train'): - assert split in { 'train', 'test' } - if split == 'train': - for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'): - yield batch - else: - for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'): - yield batch + for t in s.strip().split(" "): + tokens.add(t) + self.token2id = dict([(t, n) for n, t in enumerate(tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(tokens)]) + + # Tokenize the train and test sets + self.train_input = self.tensorize(self.train_descr) + self.test_input = self.tensorize(self.test_descr) + + def batches(self, split="train"): + assert split in {"train", "test"} + input = self.train_input if split == "train" else self.test_input + for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"): + yield self.trim(batch) def vocabulary_size(self): return len(self.token2id) - def generate(self, primer, model, nb_tokens): - t_primer = primer.strip().split(' ') - t_generated = [ ] + def test_model( + self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False + ): + nb_tokens_to_generate = self.height * self.width + 3 + result_descr = [] + + for primer_descr in primers_descr: + + results = autoregression( + model, + self.batch_size, + nb_samples=nb_per_primer, + nb_tokens_to_generate=nb_tokens_to_generate, + primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1), + device=self.device, + ) - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - input = F.pad(input, (0, 1)) # Add the next token, the one to predict - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - return ' '.join(t_primer + t_generated) - - def produce_results(self, n_epoch, model, nb_tokens = None): - if nb_tokens is None: - nb_tokens = self.height * self.width + 3 - descr = [ ] - nb_per_primer = 8 - - for primer in [ - 'red above green green top blue right of red ', - 'there is red there is yellow there is blue ', - 'red below yellow yellow below green green below blue red right yellow left green right blue left ', - 'green bottom yellow bottom green left of blue yellow right of blue blue top ', - ]: - - for k in range(nb_per_primer): - descr.append(self.generate(primer, model, nb_tokens)) - - img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ] - img = torch.cat(img, 0) - image_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image( - img / 255., - image_name, nrow = nb_per_primer, pad_value = 0.8 + l = [" ".join([self.id2token[t.item()] for t in r]) for r in results] + result_descr += l + + np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width) + + nb_requested_properties, _, nb_missing_properties = zip(*np) + + log_string( + f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}" ) - log_string(f'wrote {image_name}') - nb_missing = sum( [ - x[2] for x in picoclvr.nb_missing_properties( - descr, - height = self.height, width = self.width + np = torch.tensor(np) + count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64) + for i in range(count.size(0)): + for j in range(count.size(1)): + count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum() + + if generate_images: + img = [ + picoclvr.descr2img(d, height=self.height, width=self.width) + for d in result_descr + ] + + img = torch.cat(img, 0) + image_name = f"result_picoclvr_{n_epoch:04d}.png" + torchvision.utils.save_image( + img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8 ) - ] ) + log_string(f"wrote {image_name}") + + return count + + def produce_results(self, n_epoch, model): + primers_descr = [ + "red above green green top blue right of red ", + "there is red there is yellow there is blue ", + "red below yellow yellow below green green below blue red right yellow left green right blue left ", + "green bottom yellow bottom green left of blue yellow right of blue blue top ", + ] + + self.test_model( + n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True + ) + + # FAR TOO SLOW!!! + + # test_primers_descr=[ s.split('')[0] for s in self.test_descr ] + + # count=self.test_model( + # n_epoch, model, + # test_primers_descr, + # nb_per_primer=1, generate_images=False + # ) + + # with open(f'perf_{n_epoch:04d}.txt', 'w') as f: + # for i in range(count.size(0)): + # for j in range(count.size(1)): + # f.write(f'{count[i,j]}') + # f.write(" " if j 0: + if args.data_size is not None: train_iter = itertools.islice(train_iter, args.data_size) def yield_tokens(): - for l in tqdm.tqdm(train_iter, desc = 'vocab'): + for l in tqdm.tqdm(train_iter, desc="vocab"): yield self.tokenizer(l) self.vocab = torchtext.vocab.build_vocab_from_iterator( - yield_tokens(), - specials = [ '', '' ], - min_freq = self.min_freq + yield_tokens(), specials=["", ""], min_freq=self.min_freq ) - self.vocab.set_default_index(self.vocab[ '' ]) + self.vocab.set_default_index(self.vocab[""]) + # makes a tensor from a list of list of tokens def tensorize(self, s): a = max(len(x) for x in s) - return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) + return torch.tensor([self.vocab(x + [""] * (a - len(x))) for x in s]) def yield_batches(self, ds): - s = [ ] + s = [] for l in ds: q = self.tokenizer(l) if len(q) >= self.len_min and len(q) <= self.len_max: - s += [ q ] + s += [q] if len(s) == self.batch_size: yield self.tensorize(s) - s = [ ] + s = [] if len(s) > 0: yield self.tensorize(s) - def batches(self, split = 'train'): - data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/') + def batches(self, split="train"): + data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/") # Mostly for debug - if args.data_size > 0: + if args.data_size is not None: data_iter = itertools.islice(data_iter, args.data_size) - return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}')) + return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}")) def vocabulary_size(self): return len(self.vocab) - def produce_results(self, n_epoch, model, nb_tokens = 50): - file_name = f'result_wiki103_{n_epoch:04d}.txt' - - with open(file_name, 'w') as outfile: - for primer in [ - 'the cat is hunting a', - 'paris is the capital', - 'cars are convenient', - 'the difference between men and women is', - 'the object was blue all over and green all over it was', - 'cherries are red and lemons are', - 'cherries are sweet and lemons are', - 'two plus three equals', - 'deep learning is', - ]: - t_primer = self.tokenizer(primer) - t_generated = [ ] - - for j in range(nb_tokens): - - input = self.tensorize([ t_primer + t_generated ]).to(self.device) - input = F.pad(input, (0, 1)) # Add the next token, the one to predict - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) - if t_generated[-1] == '': break - - s = ' '.join(t_generated) - - outfile.write(f'<{primer}> {s}\n') - - log_string(f'wrote {file_name}') + def produce_results(self, n_epoch, model): + nb_tokens = 50 + file_name = f"result_wiki103_{n_epoch:04d}.txt" + + with open(file_name, "w") as outfile: + for primer in [ + "the cat is hunting a", + "paris is the capital", + "cars are convenient", + "the difference between men and women is", + "the object was blue all over and green all over it was", + "cherries are red and lemons are", + "cherries are sweet and lemons are", + "two plus three equals", + "deep learning is", + ]: + t_primer = self.tokenizer(primer) + t_generated = [] + + for j in range(nb_tokens): + + input = self.tensorize([t_primer + t_generated]).to(self.device) + input = F.pad( + input, (0, 1) + ) # Add the next token, the one to predict + output = model(input) + logits = output[0, -1] + if args.deterministic_synthesis: + t_next = logits.argmax() + else: + dist = torch.distributions.categorical.Categorical( + logits=logits + ) + t_next = dist.sample() + t_generated.append(self.vocab.lookup_token(t_next)) + if t_generated[-1] == "": + break + + s = " ".join(t_generated) + + outfile.write(f"<{primer}> {s}\n") + + log_string(f"wrote {file_name}") + ###################################################################### -class TaskMNIST(Task): - def __init__(self, batch_size, device = torch.device('cpu')): +class TaskMNIST(Task): + def __init__(self, batch_size, device=torch.device("cpu")): self.device = device self.batch_size = batch_size - def batches(self, split = 'train'): - assert split in { 'train', 'test' } + def batches(self, split="train"): + assert split in {"train", "test"} data_set = torchvision.datasets.MNIST( - root = './data', train = (split == 'train'), - download = True + root="./data", train=(split == "train"), download=True ) data_input = data_set.data.view(-1, 28 * 28).long() - if args.data_size >= 0: - data_input = data_input[:args.data_size] - for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'): + if args.data_size is not None: + data_input = data_input[: args.data_size] + for batch in tqdm.tqdm( + data_input.split(self.batch_size), desc=f"epoch-{split}" + ): yield batch def vocabulary_size(self): return 256 - def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s] = t - - image_name = f'result_mnist_{n_epoch:04d}.png' - torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., - image_name, nrow = 16, pad_value = 0.8) - log_string(f'wrote {image_name}') + def produce_results(self, n_epoch, model): + nb_samples = 64 + results = autoregression( + model, self.batch_size, nb_samples, 28 * 28, device=self.device + ) + image_name = f"result_mnist_{n_epoch:04d}.png" + torchvision.utils.save_image( + 1 - results.reshape(-1, 1, 28, 28) / 255.0, + image_name, + nrow=16, + pad_value=0.8, + ) + log_string(f"wrote {image_name}") + ###################################################################### -log_string(f'device {device}') - -if args.data == 'wiki103': - nb_epochs_default = 10 - task = TaskWiki103(batch_size = args.batch_size, device = device) -elif args.data == 'mnist': - nb_epochs_default = 25 - task = TaskMNIST(batch_size = args.batch_size, device = device) -elif args.data == 'picoclvr': - nb_epochs_default = 10 - task = TaskPicoCLVR(batch_size = args.batch_size, - height = args.picoclvr_height, - width = args.picoclvr_width, - many_colors = args.picoclvr_many_colors, - device = device) +log_string(f"device {device}") + +if args.data == "wiki103": + task = TaskWiki103(batch_size=args.batch_size, device=device) +elif args.data in {"mnist", "mnist-debug"}: + task = TaskMNIST(batch_size=args.batch_size, device=device) +elif args.data == "picoclvr": + task = TaskPicoCLVR( + batch_size=args.batch_size, + height=args.picoclvr_height, + width=args.picoclvr_width, + nb_colors=args.picoclvr_nb_colors, + device=device, + ) else: - raise ValueError(f'Unknown dataset {args.data}.') + raise ValueError(f"Unknown dataset {args.data}.") vocabulary_size = task.vocabulary_size() -log_string(f'vocabulary_size {vocabulary_size}') +log_string(f"vocabulary_size {vocabulary_size}") ############################## model = mygpt.MyGPT( - vocabulary_size = vocabulary_size, - dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden, - nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + dropout=args.dropout, ) model.to(device) nb_parameters = sum(p.numel() for p in model.parameters()) -log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') - -###################################################################### - -if args.optim == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) -else: - raise ValueError(f'Unknown optimizer {args.optim}.') +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### nb_epochs_finished = 0 if args.no_checkpoint: - log_string(f'Not trying to load checkpoint.') + log_string(f"not trying to load checkpoint.") else: try: - checkpoint = torch.load(args.checkpoint_name, map_location = device) - nb_epochs_finished = checkpoint['nb_epochs_finished'] - model.load_state_dict(checkpoint['model_state']) - optimizer.load_state_dict(checkpoint['optimizer_state']) - log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.') + checkpoint = torch.load(args.checkpoint_name) + nb_epochs_finished = checkpoint["nb_epochs_finished"] + model.load_state_dict(checkpoint["model_state"]) + torch.set_rng_state(checkpoint["rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(checkpoint["cuda_rng_state"]) + log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.") except FileNotFoundError: - log_string('Starting from scratch.') + log_string("starting from scratch.") except: - log_string('Error when loading the checkpoint.') + log_string("error when loading the checkpoint.") exit(1) ###################################################################### -nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default - token_count = 0 -for input in task.batches(split = 'train'): - token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1)) +for input in task.batches(split="train"): + token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1)) token_probas = token_count / token_count.sum() -h = -torch.xlogy(token_probas, token_probas).sum() -train_set_perplexity = math.exp(h) -log_string(f'Train set perplexity {train_set_perplexity}') +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) + +for n_epoch in range(nb_epochs_finished, args.nb_epochs): + + if args.learning_rate_end < 0: + lr = args.learning_rate + else: + u = n_epoch / (args.nb_epochs - 1) + lr = math.exp( + (1 - u) * math.log(args.learning_rate) + + u * math.log(args.learning_rate_end) + ) + log_string(f"learning_rate {lr}") -for k in range(nb_epochs_finished, nb_epochs): + if args.optim == "sgd": + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + elif args.optim == "adam": + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + elif args.optim == "adamw": + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + else: + raise ValueError(f"Unknown optimizer {args.optim}.") model.train() nb_train_samples, acc_train_loss = 0, 0.0 - for input in task.batches(split = 'train'): + for input in task.batches(split="train"): input = input.to(device) output = model(input) loss = F.cross_entropy(output.transpose(1, 2), input) @@ -497,26 +587,31 @@ for k in range(nb_epochs_finished, nb_epochs): nb_test_samples, acc_test_loss = 0, 0.0 - for input in task.batches(split = 'test'): + for input in task.batches(split="test"): input = input.to(device) output = model(input) loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) - train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) - test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}') + log_string( + f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}" + ) - task.produce_results(k, model) + task.produce_results(n_epoch, model) checkpoint = { - 'nb_epochs_finished': k + 1, - 'model_state': model.state_dict(), - 'optimizer_state': optimizer.state_dict() + "nb_epochs_finished": n_epoch + 1, + "model_state": model.state_dict(), + "rng_state": torch.get_rng_state(), } + if torch.cuda.is_available(): + checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state() + torch.save(checkpoint, args.checkpoint_name) ######################################################################