X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=main.py;h=82caebe284b6c966f3b1653aa396091edef75e46;hb=e6f60ce506f75da652d069a4926adfa4d92e3cbf;hp=c810eef06593c7939e0ad15fe690225509cd4150;hpb=fc570d4ccd5d5dee36271d34ff5c672a50a82101;p=mygpt.git

diff --git a/main.py b/main.py
index c810eef..82caebe 100755
--- a/main.py
+++ b/main.py
@@ -24,9 +24,6 @@ parser = argparse.ArgumentParser(description = 'My own GPT.')
 parser.add_argument('--log_filename',
                     type = str, default = 'train.log')
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
-
 parser.add_argument('--seed',
                     type = int, default = 0)
 
@@ -78,8 +75,8 @@ parser.add_argument('--checkpoint_name',
 ##############################
 # picoclvr options
 
-parser.add_argument('--picoclvr_many_colors',
-                    action='store_true', default = False)
+parser.add_argument('--picoclvr_nb_colors',
+                    type = int, default = 5)
 
 parser.add_argument('--picoclvr_height',
                     type = int, default = 12)
@@ -113,22 +110,34 @@ for n in vars(args):
 
 ######################################################################
 
-def produce_results(
-        self,
-        model, nb_samples, nb_tokens_to_generate, starting_input = None,
-        device = 'cpu'
+def autoregression(
+        model,
+        nb_samples, nb_tokens_to_generate, starting_input = None,
+        device = torch.device('cpu')
 ):
-    results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
-    for input in results.split(self.batch_size):
-        for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate,
+        dtype = torch.int64, device = device
+    )
+
+    if starting_input is None:
+        first = 0
+    else:
+        first = starting_input.size(1)
+        results = torch.cat((starting_input, results), 1)
+
+    for input in results.split(args.batch_size):
+        for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
             output = model(input)
             logits = output[:, s]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax(1)
-            input[:, s + 1] = t
+                t_next = logits.argmax(1)
+            input[:, s] = t_next
+
+    return results
 
 ######################################################################
@@ -149,18 +158,19 @@ import picoclvr
 
 class TaskPicoCLVR(Task):
 
     def __init__(self, batch_size,
-                 height, width, many_colors = False,
+                 height, width, nb_colors = 5,
                  device = torch.device('cpu')):
 
         def generate_descr(nb):
             descr = picoclvr.generate(
                 nb, height = self.height, width = self.width,
-                many_colors = many_colors
+                nb_colors = nb_colors
             )
 
             descr = [ s.strip().split(' ') for s in descr ]
             l = max([ len(s) for s in descr ])
+            #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
             descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]
 
             return descr
@@ -182,6 +192,7 @@ class TaskPicoCLVR(Task):
         self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
         self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
 
+        # Tokenize the train and test sets
         t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
         self.train_input = torch.tensor(t, device = self.device)
         t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
@@ -206,14 +217,15 @@ class TaskPicoCLVR(Task):
         for j in range(nb_tokens):
             t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
             input = torch.tensor(t, device = self.device)
+            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
             output = model(input)
             logits = output[0, -1]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax()
-            t_generated.append(self.id2token[t.item()])
+                t_next = logits.argmax()
+            t_generated.append(self.id2token[t_next.item()])
 
         return ' '.join(t_primer + t_generated)
 
@@ -333,14 +345,15 @@ class TaskWiki103(Task):
         for j in range(nb_tokens):
 
             input = self.tensorize([ t_primer + t_generated ]).to(self.device)
+            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
             output = model(input)
             logits = output[0, -1]
             if args.synthesis_sampling:
                 dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
+                t_next = dist.sample()
             else:
-                t = logits.argmax()
-            t_generated.append(self.vocab.lookup_token(t))
+                t_next = logits.argmax()
+            t_generated.append(self.vocab.lookup_token(t_next))
             if t_generated[-1] == '<non>': break
 
         s = ' '.join(t_generated)
@@ -373,18 +386,7 @@ class TaskMNIST(Task):
         return 256
 
     def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s] = t
-
+        results = autoregression(model, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)
@@ -405,7 +407,7 @@
 elif args.data == 'picoclvr':
     task = TaskPicoCLVR(batch_size = args.batch_size,
                         height = args.picoclvr_height, width = args.picoclvr_width,
-                        many_colors = args.picoclvr_many_colors,
+                        nb_colors = args.picoclvr_nb_colors,
                         device = device)
 else:
     raise ValueError(f'Unknown dataset {args.data}.')
@@ -443,7 +445,7 @@
 nb_epochs_finished = 0
 
 if args.no_checkpoint:
-    log_string(f'Not trying to load checkpoint.')
+    log_string(f'not trying to load checkpoint.')
 
 else:
     try:
@@ -451,13 +453,13 @@
         nb_epochs_finished = checkpoint['nb_epochs_finished']
         model.load_state_dict(checkpoint['model_state'])
         optimizer.load_state_dict(checkpoint['optimizer_state'])
-        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')
 
     except FileNotFoundError:
-        log_string('Starting from scratch.')
+        log_string('starting from scratch.')
 
     except:
-        log_string('Error when loading the checkpoint.')
+        log_string('error when loading the checkpoint.')
         exit(1)
 
 ######################################################################
@@ -470,7 +472,7 @@ for input in task.batches(split = 'train'):
 token_probas = token_count / token_count.sum()
 h = -torch.xlogy(token_probas, token_probas).sum()
 train_set_perplexity = math.exp(h)
-log_string(f'Train set perplexity {train_set_perplexity}')
+log_string(f'train set perplexity {train_set_perplexity}')
 
 for k in range(nb_epochs_finished, nb_epochs):
@@ -498,7 +500,7 @@ for k in range(nb_epochs_finished, nb_epochs):
     for input in task.batches(split = 'test'):
         input = input.to(device)
        output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_test_loss += loss.item() * input.size(0)
        nb_test_samples += input.size(0)
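For context on the refactor above: the patch replaces the per-task sampling loops with a single autoregression() helper and changes the indexing in the process. The deleted produce_results() read logits at position s but wrote the sampled token at s + 1, so position 0 of results was never generated; the new helper reads and writes at the same position s (matching the MNIST loop it replaces) and can keep a primer (starting_input) intact, generating only the positions after it. The sketch below is a minimal, self-contained illustration of that loop, not code from the patch: DummyModel and VOCAB_SIZE are assumptions standing in for mygpt's transformer, batch_size is passed explicitly instead of being read from the global args, and decoding is greedy where the patched code can also sample from a Categorical distribution.

# Minimal sketch of the generation loop factored out as autoregression().
# DummyModel and VOCAB_SIZE are stand-ins, not part of the patch.

import torch

VOCAB_SIZE = 16  # hypothetical vocabulary size, for the sketch only

class DummyModel(torch.nn.Module):
    # Stand-in for the GPT: maps each input token to a row of logits
    # over the vocabulary, so the loop below can run end to end.
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(VOCAB_SIZE, VOCAB_SIZE)

    def forward(self, input):
        return self.embed(input)

def autoregression(model, batch_size,
                   nb_samples, nb_tokens_to_generate,
                   starting_input = None, device = torch.device('cpu')):
    results = torch.zeros(nb_samples, nb_tokens_to_generate,
                          dtype = torch.int64, device = device)

    # With a primer, keep it verbatim and generate only the positions
    # that follow it, as in the patched function.
    if starting_input is None:
        first = 0
    else:
        first = starting_input.size(1)
        results = torch.cat((starting_input, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            # Greedy decoding; the patch samples from a Categorical
            # over these logits when args.synthesis_sampling is set.
            input[:, s] = logits.argmax(1)

    return results

model = DummyModel()
primer = torch.randint(VOCAB_SIZE, (4, 3))  # 4 sequences, 3-token primer
generated = autoregression(model, 2, 4, 5, starting_input = primer)
print(generated.size())  # torch.Size([4, 8]): primer + 5 generated tokens

Note that results.split(batch_size) returns views into results, so assigning into input fills the final tensor in place; the patched helper relies on the same behavior.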