X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=82caebe284b6c966f3b1653aa396091edef75e46;hb=e6f60ce506f75da652d069a4926adfa4d92e3cbf;hp=3bf7587a5f1e45bee2530ca56ce016738b95d117;hpb=046f35f38d629c9854104e855a53f0142449138f;p=mygpt.git diff --git a/main.py b/main.py index 3bf7587..82caebe 100755 --- a/main.py +++ b/main.py @@ -24,14 +24,11 @@ parser = argparse.ArgumentParser(description = 'My own GPT.') parser.add_argument('--log_filename', type = str, default = 'train.log') -parser.add_argument('--download', - type = bool, default = False) - parser.add_argument('--seed', type = int, default = 0) parser.add_argument('--nb_epochs', - type = int, default = 100) + type = int, default = -1) parser.add_argument('--batch_size', type = int, default = 25) @@ -67,7 +64,25 @@ parser.add_argument('--dropout', type = float, default = 0.1) parser.add_argument('--synthesis_sampling', - type = bool, default = True) + action='store_true', default = True) + +parser.add_argument('--no_checkpoint', + action='store_true', default = False) + +parser.add_argument('--checkpoint_name', + type = str, default = 'checkpoint.pth') + +############################## +# picoclvr options + +parser.add_argument('--picoclvr_nb_colors', + type = int, default = 5) + +parser.add_argument('--picoclvr_height', + type = int, default = 12) + +parser.add_argument('--picoclvr_width', + type = int, default = 16) ###################################################################### @@ -95,6 +110,37 @@ for n in vars(args): ###################################################################### +def autoregression( + model, + nb_samples, nb_tokens_to_generate, starting_input = None, + device = torch.device('cpu') +): + results = torch.zeros( + nb_samples, nb_tokens_to_generate, + dtype = torch.int64, device = device + ) + + if starting_input is None: + first = 0 + else: + first = starting_input.size(1) + results = torch.cat((starting_input, results), 1) + + for input in results.split(args.batch_size): + for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'): + output = model(input) + logits = output[:, s] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t_next = dist.sample() + else: + t_next = logits.argmax(1) + input[:, s] = t_next + + return results + +###################################################################### + class Task: def batches(self, split = 'train'): pass @@ -112,34 +158,45 @@ import picoclvr class TaskPicoCLVR(Task): def __init__(self, batch_size, - height = 6, width = 8, many_colors = False, + height, width, nb_colors = 5, device = torch.device('cpu')): + def generate_descr(nb): + descr = picoclvr.generate( + nb, + height = self.height, width = self.width, + nb_colors = nb_colors + ) + + descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in descr ]) + #descr = [ [ '' ] * (l - len(s)) + s for s in descr ] + descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + + return descr + + self.height = height + self.width = width self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 - descr = picoclvr.generate( - nb, - height = height, width = width, - many_colors = many_colors - ) - - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + self.train_descr = generate_descr((nb * 4) // 5) + self.test_descr = generate_descr((nb * 1) // 5) + # Build the tokenizer tokens = set() - for s in descr: - for t in s: tokens.add(t) + for d in [ self.train_descr, self.test_descr ]: + for s in d: + for t in s: tokens.add(t) self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - t = [ [ self.token2id[u] for u in s ] for s in descr ] - data_input = torch.tensor(t, device = self.device) - - self.test_input = data_input[:nb // 5] - self.train_input = data_input[nb // 5:] + # Tokenize the train and test sets + t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ] + self.train_input = torch.tensor(t, device = self.device) + t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ] + self.test_input = torch.tensor(t, device = self.device) def batches(self, split = 'train'): assert split in { 'train', 'test' } @@ -153,8 +210,29 @@ class TaskPicoCLVR(Task): def vocabulary_size(self): return len(self.token2id) - def produce_results(self, n_epoch, model, nb_tokens = 50): - img = [ ] + def generate(self, primer, model, nb_tokens): + t_primer = primer.strip().split(' ') + t_generated = [ ] + + for j in range(nb_tokens): + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + input = torch.tensor(t, device = self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t_next = dist.sample() + else: + t_next = logits.argmax() + t_generated.append(self.id2token[t_next.item()]) + + return ' '.join(t_primer + t_generated) + + def produce_results(self, n_epoch, model, nb_tokens = None): + if nb_tokens is None: + nb_tokens = self.height * self.width + 3 + descr = [ ] nb_per_primer = 8 for primer in [ @@ -165,29 +243,25 @@ class TaskPicoCLVR(Task): ]: for k in range(nb_per_primer): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - descr = [ ' '.join(t_primer + t_generated) ] - img += [ picoclvr.descr2img(descr) ] + descr.append(self.generate(primer, model, nb_tokens)) + img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ] img = torch.cat(img, 0) - file_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image(img / 255., - file_name, nrow = nb_per_primer, pad_value = 0.8) - log_string(f'wrote {file_name}') + image_name = f'result_picoclvr_{n_epoch:04d}.png' + torchvision.utils.save_image( + img / 255., + image_name, nrow = nb_per_primer, pad_value = 0.8 + ) + log_string(f'wrote {image_name}') + + nb_missing = sum( [ + x[2] for x in picoclvr.nb_missing_properties( + descr, + height = self.height, width = self.width + ) + ] ) + + log_string(f'nb_missing {nb_missing / len(descr):.02f}') ###################################################################### @@ -271,14 +345,15 @@ class TaskWiki103(Task): for j in range(nb_tokens): input = self.tensorize([ t_primer + t_generated ]).to(self.device) + input = F.pad(input, (0, 1)) # Add the next token, the one to predict output = model(input) logits = output[0, -1] if args.synthesis_sampling: dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() + t_next = dist.sample() else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) + t_next = logits.argmax() + t_generated.append(self.vocab.lookup_token(t_next)) if t_generated[-1] == '': break s = ' '.join(t_generated) @@ -311,18 +386,7 @@ class TaskMNIST(Task): return 256 def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s + 1] = t - + results = autoregression(model, nb_samples, 28 * 28, device = self.device) image_name = f'result_mnist_{n_epoch:04d}.png' torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8) @@ -330,27 +394,21 @@ class TaskMNIST(Task): ###################################################################### -def check_causality(model): - #m = model[1:] - input = torch.rand(1, 5, dim_model).requires_grad_() - output = m(input) - a = torch.zeros(output.size(1), input.size(1)) - for k in range(output.size(1)): - for d in range(output.size(2)): - g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) - a[k] += g.squeeze(0).pow(2).sum(1) - print(a) - -###################################################################### - log_string(f'device {device}') if args.data == 'wiki103': + nb_epochs_default = 10 task = TaskWiki103(batch_size = args.batch_size, device = device) elif args.data == 'mnist': + nb_epochs_default = 25 task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, device = device) + nb_epochs_default = 10 + task = TaskPicoCLVR(batch_size = args.batch_size, + height = args.picoclvr_height, + width = args.picoclvr_width, + nb_colors = args.picoclvr_nb_colors, + device = device) else: raise ValueError(f'Unknown dataset {args.data}.') @@ -366,11 +424,11 @@ model = mygpt.MyGPT( nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout ) +model.to(device) + nb_parameters = sum(p.numel() for p in model.parameters()) log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') -model.to(device) - ###################################################################### if args.optim == 'sgd': @@ -382,7 +440,41 @@ elif args.optim == 'adamw': else: raise ValueError(f'Unknown optimizer {args.optim}.') -for k in range(args.nb_epochs): +###################################################################### + +nb_epochs_finished = 0 + +if args.no_checkpoint: + log_string(f'not trying to load checkpoint.') + +else: + try: + checkpoint = torch.load(args.checkpoint_name, map_location = device) + nb_epochs_finished = checkpoint['nb_epochs_finished'] + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.') + + except FileNotFoundError: + log_string('starting from scratch.') + + except: + log_string('error when loading the checkpoint.') + exit(1) + +###################################################################### + +nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default + +token_count = 0 +for input in task.batches(split = 'train'): + token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1)) +token_probas = token_count / token_count.sum() +h = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(h) +log_string(f'train set perplexity {train_set_perplexity}') + +for k in range(nb_epochs_finished, nb_epochs): model.train() @@ -391,7 +483,7 @@ for k in range(args.nb_epochs): for input in task.batches(split = 'train'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_train_loss += loss.item() * input.size(0) nb_train_samples += input.size(0) @@ -408,15 +500,23 @@ for k in range(args.nb_epochs): for input in task.batches(split = 'test'): input = input.to(device) output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) nb_test_samples += input.size(0) train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) - log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') + log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}') task.produce_results(k, model) + checkpoint = { + 'nb_epochs_finished': k + 1, + 'model_state': model.state_dict(), + 'optimizer_state': optimizer.state_dict() + } + + torch.save(checkpoint, args.checkpoint_name) + ######################################################################