X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=aa1b51799a159b2f176154d63cce5893eabd38ad;hb=62533ba50393866c15b322074cad836684dd69e7;hp=77c4b9e556fb44eec6b6eca95f15d3e04586137e;hpb=823ed2babf4a7144a1832487e7c911e6933d5647;p=mygpt.git

diff --git a/main.py b/main.py
index 77c4b9e..aa1b517 100755
--- a/main.py
+++ b/main.py
@@ -18,20 +18,16 @@ import mygpt
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 ######################################################################
-
 parser = argparse.ArgumentParser(description = 'My own GPT.')
 
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
 parser.add_argument('--nb_epochs',
-                    type = int, default = 100)
+                    type = int, default = -1)
 
 parser.add_argument('--batch_size',
                     type = int, default = 25)
@@ -46,7 +42,10 @@ parser.add_argument('--optim',
                     type = str, default = 'adam')
 
 parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+                    type = float, default = 1e-3)
+
+parser.add_argument('--learning_rate_end',
+                    type = float, default = 1e-6)
 
 parser.add_argument('--dim_model',
                     type = int, default = 512)
@@ -66,8 +65,11 @@ parser.add_argument('--nb_blocks',
 parser.add_argument('--dropout',
                     type = float, default = 0.1)
 
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
+parser.add_argument('--deterministic_synthesis',
+                    action='store_true', default = False)
+
+parser.add_argument('--no_checkpoint',
+                    action='store_true', default = False)
 
 parser.add_argument('--checkpoint_name',
                     type = str, default = 'checkpoint.pth')
@@ -75,8 +77,8 @@ parser.add_argument('--checkpoint_name',
 ##############################
 # picoclvr options
 
-parser.add_argument('--picoclvr_many_colors',
-                    action='store_true', default = False)
+parser.add_argument('--picoclvr_nb_colors',
+                    type = int, default = 5)
 
 parser.add_argument('--picoclvr_height',
                     type = int, default = 12)
@@ -110,6 +112,37 @@ for n in vars(args):
 
 ######################################################################
 
+def autoregression(
+        model, batch_size,
+        nb_samples, nb_tokens_to_generate, primer = None,
+        device = torch.device('cpu')
+):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate,
+        dtype = torch.int64, device = device
+    )
+
+    if primer is None:
+        first = 0
+    else:
+        first = primer.size(1)
+        results = torch.cat((primer, results), 1)
+
+    for input in results.split(batch_size):
+        for s in range(first, input.size(1)):
+            output = model(input)
+            logits = output[:, s]
+            if args.deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
+                dist = torch.distributions.categorical.Categorical(logits = logits)
+                t_next = dist.sample()
+            input[:, s] = t_next
+
+    return results
+
+######################################################################
+
 class Task:
     def batches(self, split = 'train'):
         pass
@@ -117,7 +150,7 @@ class Task:
     def vocabulary_size(self):
         pass
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
         pass
 
 ######################################################################
@@ -126,107 +159,141 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [ s.strip().split(' ') for s in descr ]
+        l = max([ len(s) for s in token_descr ])
+        #token_descr = [ [ '' ] * (l - len(s)) + s for s in token_descr ]
+        token_descr = [ s + [ '' ] * (l - len(s)) for s in token_descr ]
+        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
+        return torch.tensor(id_descr, device = self.device)
+
+    def trim(self, x, token = ''):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
+
     def __init__(self, batch_size,
-                 height, width, many_colors = False,
+                 height, width, nb_colors = 5,
                  device = torch.device('cpu')):
 
         def generate_descr(nb):
-            descr = picoclvr.generate(
+            return picoclvr.generate(
                 nb,
                 height = self.height, width = self.width,
-                many_colors = many_colors
+                nb_colors = nb_colors
             )
-            descr = [ s.strip().split(' ') for s in descr ]
-            l = max([ len(s) for s in descr ])
-            descr = [ s + [ '' ] * (l - len(s)) for s in descr ]
-
-            return descr
-
         self.height = height
         self.width = width
         self.batch_size = batch_size
         self.device = device
         nb = args.data_size if args.data_size > 0 else 250000
 
+        log_string(f'generating {nb} samples (can take some time)')
         self.train_descr = generate_descr((nb * 4) // 5)
         self.test_descr = generate_descr((nb * 1) // 5)
 
-        tokens = set()
+        # Build the tokenizer
+        tokens = { '' }
        for d in [ self.train_descr, self.test_descr ]:
             for s in d:
-                for t in s: tokens.add(t)
+                for t in s.strip().split(' '): tokens.add(t)
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
-        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
-        self.train_input = torch.tensor(t, device = self.device)
-        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
-        self.test_input = torch.tensor(t, device = self.device)
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
 
     def batches(self, split = 'train'):
         assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
+        input = self.train_input if split == 'train' else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
+            yield self.trim(batch)
 
     def vocabulary_size(self):
         return len(self.token2id)
 
-    def generate(self, primer, model, nb_tokens):
-        t_primer = primer.strip().split(' ')
-        t_generated = [ ]
+    def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = [ ]
 
-        for j in range(nb_tokens):
-            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
-            input = torch.tensor(t, device = self.device)
-            output = model(input)
-            logits = output[0, -1]
-            if args.synthesis_sampling:
-                dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
-            else:
-                t = logits.argmax()
-            t_generated.append(self.id2token[t.item()])
-
-        return ' '.join(t_primer + t_generated)
-
-    def produce_results(self, n_epoch, model, nb_tokens = None):
-        if nb_tokens is None:
-            nb_tokens = self.height * self.width + 3
-        descr = [ ]
-        nb_per_primer = 8
-
-        for primer in [
-            'red above green green top blue right of red ',
-            'there is red there is yellow there is blue ',
-            'red below yellow yellow below green green below blue red right yellow left green right blue left ',
-            'green bottom yellow bottom green left of blue yellow right of blue blue top ',
-        ]:
-
-            for k in range(nb_per_primer):
-                descr.append(self.generate(primer, model, nb_tokens))
-
-        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
-        img = torch.cat(img, 0)
-        file_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(
-            img / 255.,
-            file_name, nrow = nb_per_primer, pad_value = 0.8
+        for primer_descr in primers_descr:
+
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples = nb_per_primer,
+                nb_tokens_to_generate = nb_tokens_to_generate,
+                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
+                device = self.device
+            )
+
+            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
+            result_descr += l
+
+        np = picoclvr.nb_properties(
+            result_descr,
+            height = self.height, width = self.width
         )
-        log_string(f'wrote {file_name}')
-        nb_missing = sum( [
-            x[2] for x in picoclvr.nb_missing_properties(
-                descr,
-                height = self.height, width = self.width
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
+
+        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')
+
+        np=torch.tensor(np)
+        count=torch.empty(np[:,0].max()+1,np[:,2].max()+1,dtype=torch.int64)
+        for i in range(count.size(0)):
+            for j in range(count.size(1)):
+                count[i,j]=((np[:,0]==i).long()*(np[:,2]==j).long()).sum()
+
+        if generate_images:
+            img = [
+                picoclvr.descr2img(d, height = self.height, width = self.width)
+                for d in result_descr
+            ]
+
+            img = torch.cat(img, 0)
+            image_name = f'result_picoclvr_{n_epoch:04d}.png'
+            torchvision.utils.save_image(
+                img / 255.,
+                image_name, nrow = nb_per_primer, pad_value = 0.8
             )
-        ] )
+            log_string(f'wrote {image_name}')
+
+        return count
+
+    def produce_results(self, n_epoch, model):
+        primers_descr = [
+            'red above green green top blue right of red ',
+            'there is red there is yellow there is blue ',
+            'red below yellow yellow below green green below blue red right yellow left green right blue left ',
+            'green bottom yellow bottom green left of blue yellow right of blue blue top ',
+        ]
+
+        self.test_model(
+            n_epoch, model,
+            primers_descr,
+            nb_per_primer=8, generate_images=True
+        )
+
+        # FAR TOO SLOW!!!
 
-        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
+        # test_primers_descr=[ s.split('')[0] for s in self.test_descr ]
+
+        # count=self.test_model(
+            # n_epoch, model,
+            # test_primers_descr,
+            # nb_per_primer=1, generate_images=False
+        # )
+
+        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+            # for i in range(count.size(0)):
+                # for j in range(count.size(1)):
+                    # f.write(f'{count[i,j]}')
+                    # f.write(" " if j', '' ],
+            specials = [ '', '' ],
             min_freq = self.min_freq
         )
         self.vocab.set_default_index(self.vocab[ '' ])
 
+    # makes a tensor from a list of list of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ])
+        return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ])
 
     def yield_batches(self, ds):
         s = [ ]
@@ -289,7 +357,8 @@ class TaskWiki103(Task):
     def vocabulary_size(self):
         return len(self.vocab)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
+        nb_tokens = 50
         file_name = f'result_wiki103_{n_epoch:04d}.txt'
 
         with open(file_name, 'w') as outfile:
@@ -310,15 +379,16 @@ class TaskWiki103(Task):
 
             for j in range(nb_tokens):
                 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+                input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                 output = model(input)
                 logits = output[0, -1]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
+                if args.deterministic_synthesis:
+                    t_next = logits.argmax()
                 else:
-                    t = logits.argmax()
-                t_generated.append(self.vocab.lookup_token(t))
-                if t_generated[-1] == '': break
+                    dist = torch.distributions.categorical.Categorical(logits = logits)
+                    t_next = dist.sample()
+                t_generated.append(self.vocab.lookup_token(t_next))
+                if t_generated[-1] == '': break
 
             s = ' '.join(t_generated)
 
@@ -349,19 +419,9 @@ class TaskMNIST(Task):
     def vocabulary_size(self):
         return 256
 
-    def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s + 1] = t
-
+    def produce_results(self, n_epoch, model):
+        nb_samples = 64
+        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., image_name, nrow = 16, pad_value = 0.8)
 
@@ -369,30 +429,20 @@ class TaskMNIST(Task):
 
 ######################################################################
 
-def check_causality(model):
-    #m = model[1:]
-    input = torch.rand(1, 5, dim_model).requires_grad_()
-    output = m(input)
-    a = torch.zeros(output.size(1), input.size(1))
-    for k in range(output.size(1)):
-        for d in range(output.size(2)):
-            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
-            a[k] += g.squeeze(0).pow(2).sum(1)
-    print(a)
-
-######################################################################
-
 log_string(f'device {device}')
 
 if args.data == 'wiki103':
+    nb_epochs_default = 10
     task = TaskWiki103(batch_size = args.batch_size, device = device)
 elif args.data == 'mnist':
+    nb_epochs_default = 25
     task = TaskMNIST(batch_size = args.batch_size, device = device)
 elif args.data == 'picoclvr':
+    nb_epochs_default = 10
     task = TaskPicoCLVR(batch_size = args.batch_size,
                         height = args.picoclvr_height, width = args.picoclvr_width,
-                        many_colors = args.picoclvr_many_colors,
+                        nb_colors = args.picoclvr_nb_colors,
                         device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
 
@@ -416,36 +466,57 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
 ######################################################################
 
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
-
-######################################################################
-
 nb_epochs_finished = 0
 
-try:
-    checkpoint = torch.load(args.checkpoint_name, map_location = device)
-    nb_epochs_finished = checkpoint['nb_epochs_finished']
-    model.load_state_dict(checkpoint['model_state'])
-    optimizer.load_state_dict(checkpoint['optimizer_state'])
-    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+if args.no_checkpoint:
+    log_string(f'not trying to load checkpoint.')
 
-except FileNotFoundError:
-    print('Starting from scratch.')
-
-except:
-    print('Error when loading the checkpoint.')
-    exit(1)
+else:
+    try:
+        checkpoint = torch.load(args.checkpoint_name)
+        nb_epochs_finished = checkpoint['nb_epochs_finished']
+        model.load_state_dict(checkpoint['model_state'])
+        torch.set_rng_state(checkpoint['rng_state'])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+    except FileNotFoundError:
+        log_string('starting from scratch.')
+
+    except:
+        log_string('error when loading the checkpoint.')
+        exit(1)
 
 ######################################################################
 
-for k in range(nb_epochs_finished, args.nb_epochs):
+nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
+
+token_count = 0
+for input in task.batches(split = 'train'):
+    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+for n_epoch in range(nb_epochs_finished, nb_epochs):
+
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (nb_epochs - 1)
+        lr = math.exp((1 - u) * math.log(args.learning_rate) +
+                      u * math.log(args.learning_rate_end))
+    log_string(f'learning_rate {lr}')
+
+    if args.optim == 'sgd':
+        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
+    elif args.optim == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
+    elif args.optim == 'adamw':
+        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
+    else:
+        raise ValueError(f'Unknown optimizer {args.optim}.')
 
     model.train()
 
@@ -454,7 +525,7 @@ for k in range(nb_epochs_finished, args.nb_epochs):
 
     for input in task.batches(split = 'train'):
         input = input.to(device)
         output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
 
@@ -471,23 +542,26 @@ for k in range(nb_epochs_finished, args.nb_epochs):
 
     for input in task.batches(split = 'test'):
         input = input.to(device)
         output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_test_loss += loss.item() * input.size(0)
         nb_test_samples += input.size(0)
 
     train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
     test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
 
-    log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
+    log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')
 
-    task.produce_results(k, model)
+    task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
+        'nb_epochs_finished': n_epoch + 1,
         'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        'rng_state': torch.get_rng_state(),
     }
+    if torch.cuda.is_available():
+        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
 ######################################################################
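
Note on the new autoregression() helper (this note and the sketch below are editorial, not part of the patch): it fills a token buffer left to right and, at each position, either takes the argmax of the logits (--deterministic_synthesis) or samples from the corresponding categorical distribution. A minimal standalone sketch of the same loop; it uses the common convention where the logits at position s-1 score the token at position s, whereas the patched mygpt model emits the scores for position s directly at output[:, s]. The toy uniform "model" is purely illustrative.

import torch

def generate(model, primer, nb_new_tokens, deterministic = False):
    # primer: (N, T0) int64 tensor of known tokens; extend the buffer by nb_new_tokens
    x = torch.cat(
        (primer, torch.zeros(primer.size(0), nb_new_tokens, dtype = torch.int64, device = primer.device)),
        dim = 1
    )
    for s in range(primer.size(1), x.size(1)):
        logits = model(x)[:, s - 1]  # scores for the token at position s
        if deterministic:
            x[:, s] = logits.argmax(dim = 1)           # greedy: always the most likely token
        else:
            x[:, s] = torch.distributions.categorical.Categorical(logits = logits).sample()
    return x

# illustrative stand-in for a language model: uniform logits over a vocabulary of 10 tokens
toy_model = lambda x: torch.zeros(x.size(0), x.size(1), 10)
print(generate(toy_model, torch.zeros(2, 1, dtype = torch.int64), 5))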
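
Note on --learning_rate / --learning_rate_end: the formula added in the training loop interpolates the learning rate in log space across epochs, i.e. it decays by a constant multiplicative factor every epoch. A small sketch mirroring that formula (defaults taken from the new argparse values; assumes nb_epochs > 1):

import math

def lr_for_epoch(n_epoch, nb_epochs, lr_start = 1e-3, lr_end = 1e-6):
    # geometric (log-space) interpolation from lr_start at the first epoch to lr_end at the last
    u = n_epoch / (nb_epochs - 1)
    return math.exp((1 - u) * math.log(lr_start) + u * math.log(lr_end))

print([ lr_for_epoch(n, 4) for n in range(4) ])  # ~[1e-03, 1e-04, 1e-05, 1e-06]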
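
Note on the 'train_set' value added to the perplexity log line: it is exp of the entropy of the empirical token distribution of the training set, i.e. the perplexity a predictor would reach if it only captured unigram token frequencies. A self-contained example with made-up counts (the patch accumulates F.one_hot(input, ...) over the training batches instead):

import math
import torch

# toy stream over a vocabulary of 4 tokens, with frequencies 0.1 / 0.2 / 0.3 / 0.4
tokens = torch.tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3])
token_count = torch.nn.functional.one_hot(tokens, num_classes = 4).sum(0).float()
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
print(train_set_perplexity)  # ~3.60, versus 4.0 for a uniform distribution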