3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Run on the GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
20 ######################################################################
# Command-line configuration: logging, dataset/task selection, optimization
# hyper-parameters, model geometry, and checkpointing.

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action='store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# Cap on the number of samples / corpus lines used; <= 0 means "use all".
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

# BUGFIX: --synthesis_sampling is a store_true flag whose default is already
# True, so it was a no-op and sampling could never be disabled from the
# command line. This companion flag clears the same attribute, so greedy
# decoding becomes reachable while the old flag and the
# args.synthesis_sampling attribute stay backward-compatible.
parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
# picoclvr-specific options

parser.add_argument('--picoclvr_many_colors',
                    action='store_true', default = False)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

# Kept open for the lifetime of the run; log_string() writes into it.
log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
99 ######################################################################
102 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
104 if log_file is not None:
105 log_file.write(t + s + '\n')
112 log_string(f'args.{n} {getattr(args, n)}')
114 ######################################################################
117 def batches(self, split = 'train'):
120 def vocabulary_size(self):
123 def produce_results(self, n_epoch, model, nb_tokens = 50):
126 ######################################################################
class TaskPicoCLVR(Task):
    # Language modeling over picoclvr scene descriptions: token sequences
    # such as 'red above green <sep> ... <img> ...' produced by the
    # picoclvr module, padded with '<unk>' to a common length.

    def __init__(self, batch_size,
                 height, width, many_colors = False,
                 device = torch.device('cpu')):

        # Generate nb descriptions and pad them all to the same length.
        def generate_descr(nb):
            descr = picoclvr.generate(
                height = self.height, width = self.width,
                many_colors = many_colors

            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            # '<unk>' doubles as the padding token
            descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        self.batch_size = batch_size

        nb = args.data_size if args.data_size > 0 else 250000

        # 80% / 20% train / test split of the generated descriptions.
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build token <-> id maps from every token seen in either split.
        for d in [ self.train_descr, self.test_descr ]:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Encode both splits as LongTensors of token ids.
        t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
        self.train_input = torch.tensor(t, device = self.device)
        t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
        self.test_input = torch.tensor(t, device = self.device)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
        for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):

    def vocabulary_size(self):
        return len(self.token2id)

    # Autoregressively extend the primer by nb_tokens tokens, re-encoding
    # and re-running the model on the whole sequence at every step.
    def generate(self, primer, model, nb_tokens):
        t_primer = primer.strip().split(' ')

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            output = model(input)
            logits = output[0, -1]  # scores for the next token
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
            t_generated.append(self.id2token[t.item()])

        return ' '.join(t_primer + t_generated)

    def produce_results(self, n_epoch, model, nb_tokens = None):
        if nb_tokens is None:
            # one token per grid cell plus a small margin for separators
            nb_tokens = self.height * self.width + 3
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',

        for k in range(nb_per_primer):
            descr.append(self.generate(primer, model, nb_tokens))

        # Render every generated description and save them as one image grid,
        # one row per primer.
        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            file_name, nrow = nb_per_primer, pad_value = 0.8
        log_string(f'wrote {file_name}')

        # Count the stated properties that the rendered images do not
        # actually satisfy, and log the per-description average.
            x[2] for x in picoclvr.nb_missing_properties(
                height = self.height, width = self.width
        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
234 ######################################################################
class TaskWiki103(Task):
    # Word-level language modeling on WikiText-103 (via torchtext): a
    # basic_english tokenizer, a min-frequency vocabulary, and batches of
    # sentences padded with '<non>' to a common length.

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min      # keep only lines with at least this many tokens
        self.len_max = len_max      # ... and at most this many
        self.min_freq = min_freq    # words rarer than this fall back to '<unk>'

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Optionally restrict how much of the corpus feeds the vocabulary.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq

        # Out-of-vocabulary words map to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    # Turn a list of token lists into one tensor of ids, right-padding
    # every row with '<non>' up to the length of the longest one.
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    # Group the lines of ds into tensorized batches of batch_size, keeping
    # only lines whose token count lies in [len_min, len_max]; the trailing
    # partial batch is yielded too.
    def yield_batches(self, ds):
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                if len(s) == self.batch_size:
                    yield self.tensorize(s)

        yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    # Complete a fixed set of text primers with nb_tokens model-generated
    # tokens each, and write the results to a per-epoch file.
    def produce_results(self, n_epoch, model, nb_tokens = 50):
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',

                t_primer = self.tokenizer(primer)

                for j in range(nb_tokens):

                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    output = model(input)
                    logits = output[0, -1]  # scores for the next token
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                    t_generated.append(self.vocab.lookup_token(t))
                    # '<non>' is the padding token: stop when it is produced
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
332 ######################################################################
class TaskMNIST(Task):
    # Autoregressive modeling of MNIST digits seen as sequences of
    # 28 * 28 pixel-intensity tokens.

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
        # Flatten each 28x28 image into a length-784 integer sequence.
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):

    def vocabulary_size(self):

    # Synthesize nb_samples images one pixel at a time, in raster order,
    # then save them as a 16-wide image grid.
    def produce_results(self, n_epoch, model, nb_samples = 64):
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]  # prediction for the pixel at s + 1
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)

        image_name = f'result_mnist_{n_epoch:04d}.png'
        # Pixel tokens are 0..255: scale to [0, 1] and invert before saving.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
373 ######################################################################
# Sanity check that a model is causal: the gradient of output position k
# w.r.t. input position j must vanish whenever j > k, so a[k, j] below
# should be (numerically) zero above the diagonal.
def check_causality(model):
    # NOTE(review): dim_model is not defined at module level in this file —
    # presumably args.dim_model was intended; confirm before relying on this.
    input = torch.rand(1, 5, dim_model).requires_grad_()
    # a[k, j] accumulates the squared gradient of every component of output
    # position k with respect to input position j.
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
386 ######################################################################
log_string(f'device {device}')

# Instantiate the task selected with --data; every task exposes batches(),
# vocabulary_size() and produce_results().
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        many_colors = args.picoclvr_many_colors,
    raise ValueError(f'Unknown dataset {args.data}.')

# The embedding / output size of the model is dictated by the task's data.
vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# Model geometry comes straight from the command line.
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
420 ######################################################################
# Map each supported --optim value to its torch.optim constructor; an
# unknown name is rejected before any optimizer is built, with the same
# error message as before.
_optim_ctors = {
    'sgd': torch.optim.SGD,
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW,
}

if args.optim not in _optim_ctors:
    raise ValueError(f'Unknown optimizer {args.optim}.')

optimizer = _optim_ctors[args.optim](model.parameters(), lr = args.learning_rate)
431 ######################################################################
nb_epochs_finished = 0

# Resume from a previously saved checkpoint unless --no_checkpoint was
# given; a missing checkpoint file simply means a fresh start.
if args.no_checkpoint:
    log_string(f'Not trying to load checkpoint.')

    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    log_string(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    log_string('Starting from scratch.')

    # Any other failure while reading the checkpoint is reported here.
    log_string('Error when loading the checkpoint.')
453 ######################################################################
# Main training loop; resumes at nb_epochs_finished when a checkpoint was
# loaded, and checkpoints again after every epoch.
for k in range(nb_epochs_finished, args.nb_epochs):

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token prediction: logits at position t are scored against the
        # token at position t + 1 (cross_entropy wants the class dimension
        # second, hence the transpose).
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        # Accumulate the summed (not mean) loss so the final average is
        # exact even with a smaller trailing batch.
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()

    # Evaluation on the test split: same loss, no gradient tracking.
    with torch.autograd.no_grad():

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

    # exp of the mean loss, clamped at exp(100) to avoid float overflow.
    train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
    test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

    log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

    task.produce_results(k, model)

    # Save everything needed to resume after this epoch.
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()

    torch.save(checkpoint, args.checkpoint_name)
500 ######################################################################