3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 ######################################################################
# Command-line interface. Defaults reproduce the paper-ish configuration;
# every hyper-parameter of the model and the training loop is exposed here.
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action = 'store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

# -1 means "use the per-dataset default chosen below"
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full dataset"
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUG FIX: 'store_true' combined with default=True made this flag a no-op
# (it could never be turned off from the command line). The original flag
# and its default are kept for backward compatibility; the companion
# --no_synthesis_sampling switch below actually disables sampling, which
# makes the argmax branches guarded by args.synthesis_sampling reachable.
parser.add_argument('--synthesis_sampling',
                    action = 'store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action = 'store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
# picoclvr options

parser.add_argument('--picoclvr_many_colors',
                    action = 'store_true', default = False)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)
######################################################################

args = parser.parse_args()

# Opened once, kept for the whole run; log_string() writes into it.
log_file = open(args.log_filename, 'w')

# Make runs reproducible (per --seed) for torch-side randomness.
torch.manual_seed(args.seed)

######################################################################
def log_string(s):
    """Write s, prefixed with a timestamp, to the log file (if any) and to stdout."""
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())

    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()

    # flush so the console log keeps up with long epochs
    print(t + s, flush = True)

# Record the full configuration at the top of the log.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')

######################################################################
# NOTE(review): fragment of an autoregressive generation helper. The `def`
# line naming this function is missing from the file, as are the
# non-sampling (argmax) branch and the write-back of the sampled token --
# recover them from version control before running. Also suspicious:
# `self.batch_size` is referenced although no `self` appears in the visible
# parameter list -- confirm against the original source.
118 model, nb_samples, nb_tokens_to_generate, starting_input = None,
# One row per sample; tokens are generated left to right into this buffer.
121 results = torch.zeros(nb_samples, nb_tokens_to_generate, dtype = torch.int64, device = device)
122 for input in results.split(self.batch_size):
123 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
# Full forward pass per position (O(n^2) overall) -- simple but slow.
124 output = model(input)
125 logits = output[:, s]
126 if args.synthesis_sampling:
# Sample the next token from the softmax of the logits.
127 dist = torch.distributions.categorical.Categorical(logits = logits)
133 ######################################################################
class Task:
    """Interface shared by the dataset tasks below (wiki103, MNIST, picoclvr)."""

    def batches(self, split = 'train'):
        # Yield batches of token sequences as LongTensors for the given split.
        pass

    def vocabulary_size(self):
        # Number of distinct token values the model must handle.
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Generate samples with the model and write/log them for this epoch.
        pass

######################################################################
class TaskPicoCLVR(Task):
    """Synthetic picoclvr scenes described as token sequences.

    Generates its own train/test corpora with picoclvr.generate, builds a
    token vocabulary from them, and renders generated descriptions back to
    images for evaluation.
    """

    def __init__(self, batch_size,
                 height, width, many_colors = False,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            # NOTE(review): the first positional argument (the number of
            # samples) was missing from the mangled source -- confirm the
            # picoclvr.generate signature.
            descr = picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                many_colors = many_colors
            )

            # Tokenize on spaces and right-pad with '<unk>' so all
            # sequences in the corpus have the same length.
            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

            return descr

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        # 80% / 20% train/test split
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = set()
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
        self.train_input = torch.tensor(t, device = self.device)
        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
        self.test_input = torch.tensor(t, device = self.device)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def generate(self, primer, model, nb_tokens):
        """Autoregressively extend the primer by nb_tokens tokens and return the full string."""
        t_primer = primer.strip().split(' ')
        t_generated = [ ]

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            input = F.pad(input, (0, 1)) # Add the next token, the one to predict
            output = model(input)
            logits = output[0, -1]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t = dist.sample()
            else:
                t = logits.argmax()
            t_generated.append(self.id2token[t.item()])

        return ' '.join(t_primer + t_generated)

    def produce_results(self, n_epoch, model, nb_tokens = None):
        if nb_tokens is None:
            # Enough tokens for a full scene description after the primer.
            nb_tokens = self.height * self.width + 3
        descr = [ ]
        nb_per_primer = 8

        for primer in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                descr.append(self.generate(primer, model, nb_tokens))

        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        # assumes descr2img returns pixel values in [0, 255] -- TODO confirm
        torchvision.utils.save_image(
            img / 255.,
            image_name, nrow = nb_per_primer, pad_value = 0.8
        )
        log_string(f'wrote {image_name}')

        # Count constraints of the descriptions that the rendered scenes
        # fail to satisfy, and log the average per generated sample.
        nb_missing = sum( [
            x[2] for x in picoclvr.nb_missing_properties(
                descr,
                height = self.height, width = self.width
            )
        ] )

        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
255 ######################################################################
class TaskWiki103(Task):
    """Language modeling on WikiText-103.

    Keeps lines whose token length is in [len_min, len_max], builds a
    vocabulary with min_freq, and pads batches with '<non>'.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging: cap the data used to build the vocabulary.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        # Out-of-vocabulary words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        """Convert a list of token lists into a LongTensor, right-padded with '<non>'."""
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        """Accumulate length-filtered tokenized lines and yield them as padded batches."""
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s.append(q)
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Final, possibly short, batch.
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debugging: cap the number of lines per epoch.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Complete a fixed set of primers with the model and write them to a text file."""
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t))
                    # '<non>' is the padding token: stop at end of sequence.
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
354 ######################################################################
class TaskMNIST(Task):
    """Autoregressive pixel modeling of MNIST digits, one token per pixel value."""

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        # NOTE(review): download flag reconstructed; the CLI also has
        # --download, which may have been intended here -- confirm.
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        # Flatten each 28x28 image into a 784-token sequence of pixel values.
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        # One token per possible 8-bit pixel intensity.
        return 256

    def produce_results(self, n_epoch, model, nb_samples = 64):
        """Sample nb_samples digit images pixel by pixel and save them as a PNG grid."""
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            # input is a view into results, so writing the chosen pixel
            # below fills results in place.
            for s in tqdm.tqdm(range(input.size(1)), desc = 'synth'):
                output = model(input)
                logits = output[:, s]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t = dist.sample()
                else:
                    t = logits.argmax(1)
                input[:, s] = t

        image_name = f'result_mnist_{n_epoch:04d}.png'
        # Invert so digits are dark on light background.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
395 ######################################################################
log_string(f'device {device}')

# Pick the task and a sensible default epoch count for it
# (used when --nb_epochs is left at -1).
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        many_colors = args.picoclvr_many_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
419 ##############################
# NOTE(review): the statement that constructs the model (something like
# `model = <ModelClass>(` taking the keyword arguments below) is missing
# from the file, as is the line moving the model to `device`; recover them
# from version control. The kwargs below map the CLI hyper-parameters
# straight into the transformer configuration.
422 vocabulary_size = vocabulary_size,
423 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
424 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
# Total trainable parameter count, logged in millions for readability.
429 nb_parameters = sum(p.numel() for p in model.parameters())
430 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
432 ######################################################################
# Build the optimizer selected on the command line; anything else is an error.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
443 ######################################################################
nb_epochs_finished = 0

# Resume from a checkpoint unless explicitly disabled. A missing file is
# the normal "fresh start" case; any other failure aborts rather than
# silently retraining from scratch over a corrupt checkpoint.
if args.no_checkpoint:
    log_string(f'Not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name, map_location = device)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('Starting from scratch.')

    except Exception:
        log_string('Error when loading the checkpoint.')
        exit(1)
465 ######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Empirical token distribution of the training set; its entropy, reported
# as a perplexity, is the floor a unigram model would reach and gives a
# baseline for the model's train/test perplexities logged below.
token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
# xlogy handles zero-probability tokens (0 * log 0 == 0).
h = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(h)
log_string(f'Train set perplexity {train_set_perplexity}')
# Main loop: one training pass + one evaluation pass per epoch, then a
# checkpoint. Resumes at nb_epochs_finished when a checkpoint was loaded.
for k in range(nb_epochs_finished, nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token cross-entropy; cross_entropy expects (N, C, T) logits.
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Cap the exponent so an early diverging model cannot overflow.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)

    # Save after every epoch so a crash loses at most one epoch of work.
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
522 ######################################################################