3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
16 ######################################################################
# Select the compute device once, at import time: prefer CUDA when a GPU
# is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
20 ######################################################################
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Logging / bookkeeping.
parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action = 'store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

# Training schedule.
parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

# Dataset selection ('wiki103', 'mnist', or 'picoclvr') and optional cap
# on the number of samples (-1 means "use everything").
parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

# Transformer geometry.
parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUG FIX: this flag was declared with action='store_true' AND
# default=True, so passing it was a no-op and sampling could never be
# disabled. The flag is kept (backward compatible) and a
# --no_synthesis_sampling switch is added to actually turn sampling off
# (falling back to greedy argmax synthesis).
parser.add_argument('--synthesis_sampling',
                    action = 'store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
# picoclvr options.

parser.add_argument('--picoclvr_many_colors',
                    action = 'store_true', default = False)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################
89 args = parser.parse_args()

# Plain-text training log, truncated at every run ('w' mode).
91 log_file = open(args.log_filename, 'w')

# Seed the global torch RNG for reproducible runs.
94 torch.manual_seed(args.seed)

96 ######################################################################

# NOTE(review): the `def log_string(s):` header (and any stdout echo /
# flush it performed) is missing from this view; the lines below are the
# visible remnant of its body: prefix the message with a local timestamp
# and append it to the log file.
99 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())

101 if log_file is not None:
102 log_file.write(t + s + '\n')

# NOTE(review): the `for n in vars(args):` loop header that drives this
# call (dumping every parsed argument into the log) is not visible here.
109 log_string(f'args.{n} {getattr(args, n)}')
111 ######################################################################
# NOTE(review): the `class Task:` header and the method bodies
# (presumably bare `pass` statements) are missing from this view; only
# the abstract interface that TaskPicoCLVR / TaskWiki103 / TaskMNIST
# implement below survives.

# Yield one epoch worth of batches for the given split.
114 def batches(self, split = 'train'):

# Number of distinct tokens in the task's vocabulary.
117 def vocabulary_size(self):

# Synthesize samples with `model` after epoch `n_epoch` and report them.
120 def produce_results(self, n_epoch, model, nb_tokens = 50):
123 ######################################################################
# Autoregressive modeling of picoclvr scene descriptions: generates a
# corpus of textual descriptions, tokenizes them, and pads them to a
# common length so they batch as a single integer tensor.
# NOTE(review): several interior lines are missing from this view (e.g.
# the self.height/self.width/self.device assignments, `tokens = set()`,
# the `yield batch` lines, the argmax/sample branches) — flagged below.
127 class TaskPicoCLVR(Task):

129 def __init__(self, batch_size,
130 height, width, many_colors = False,
131 device = torch.device('cpu')):

135 self.batch_size = batch_size

# Default corpus size is 250k descriptions unless --data_size caps it.
137 nb = args.data_size if args.data_size > 0 else 250000

# NOTE(review): self.height / self.width are read here but their
# assignments are not visible in this view — presumably set just above.
139 descr = picoclvr.generate(
141 height = self.height, width = self.width,
142 many_colors = many_colors

145 # self.test_descr = descr[:nb // 5]
146 # self.train_descr = descr[nb // 5:]

# Tokenize on spaces and right-pad every description with '<unk>' so
# all sequences share the maximum length l.
148 descr = [ s.strip().split(' ') for s in descr ]
149 l = max([ len(s) for s in descr ])
150 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

# Build the token <-> id bijection from the set of tokens seen.
# NOTE(review): the `tokens = set()` / `for s in descr:` lines that feed
# this loop are not visible here.
154 for t in s: tokens.add(t)
155 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
156 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

# Encode the whole corpus as one (nb, l) integer tensor; first fifth is
# the test split, the rest is train.
158 t = [ [ self.token2id[u] for u in s ] for s in descr ]
159 data_input = torch.tensor(t, device = self.device)

161 self.test_input = data_input[:nb // 5]
162 self.train_input = data_input[nb // 5:]

# Yield batches of the requested split, with a tqdm progress bar.
# NOTE(review): the `yield batch` line of each loop is not visible.
164 def batches(self, split = 'train'):
165 assert split in { 'train', 'test' }
167 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
170 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):

173 def vocabulary_size(self):
174 return len(self.token2id)

# Autoregressively extend `primer` by nb_tokens tokens, re-running the
# model on the growing sequence at every step.
176 def generate(self, primer, model, nb_tokens):
177 t_primer = primer.strip().split(' ')
# NOTE(review): the `t_generated = []` initialization is not visible.
180 for j in range(nb_tokens):
181 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
182 input = torch.tensor(t, device = self.device)
183 output = model(input)
# Next-token logits from the last position of the only sequence.
184 logits = output[0, -1]
# Either sample from the categorical distribution or (else branch, not
# visible here) presumably take the argmax.
185 if args.synthesis_sampling:
186 dist = torch.distributions.categorical.Categorical(logits = logits)
190 t_generated.append(self.id2token[t.item()])

192 return ' '.join(t_primer + t_generated)

# Generate images for a fixed list of primer descriptions, save a grid
# as a PNG, and log how many requested properties the model missed.
194 def produce_results(self, n_epoch, model, nb_tokens = None):
195 if nb_tokens is None:
# One token per grid cell plus (presumably) separators/markers.
196 nb_tokens = self.height * self.width + 3
# NOTE(review): the `descr = []`, `nb_per_primer = ...` and
# `for primer in [` lines around these string literals are not visible.
201 'red above green <sep> green top <sep> blue right of red <img>',
202 'there is red <sep> there is yellow <sep> there is blue <img>',
203 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
204 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',

207 for k in range(nb_per_primer):
208 descr.append(self.generate(primer, model, nb_tokens))

# Render each generated description to an image and tile them.
210 img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
211 img = torch.cat(img, 0)
212 file_name = f'result_picoclvr_{n_epoch:04d}.png'
213 torchvision.utils.save_image(img / 255.,
214 file_name, nrow = nb_per_primer, pad_value = 0.8)
215 log_string(f'wrote {file_name}')

# Average number of missing properties per generated description.
217 nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr, height = self.height, width = self.width) ] )
218 log_string(f'nb_missing {nb_missing / len(descr):.02f}')
220 ######################################################################
# Language modeling on WikiText-103 via torchtext: builds a frequency-
# pruned vocabulary over the train split, batches same-length-padded
# token-id tensors, and synthesizes continuations of fixed primers.
# NOTE(review): interior lines are missing from this view (self.device
# assignment, generator `def`, vocab-iterator argument, loop bodies) —
# flagged inline below.
222 class TaskWiki103(Task):

224 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
225 device = torch.device('cpu')):

227 self.batch_size = batch_size
228 self.len_min = len_min
229 self.len_max = len_max
230 self.min_freq = min_freq
# NOTE(review): a `self.device = device` assignment is expected here
# (self.device is read in produce_results) but is not visible.

233 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
234 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

# Optionally cap the number of lines used to build the vocabulary.
237 if args.data_size > 0:
238 train_iter = itertools.islice(train_iter, args.data_size)

# NOTE(review): the `def yield_tokens():` (or similar) generator header
# wrapping this loop is not visible.
241 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
242 yield self.tokenizer(l)

# '<unk>' is the OOV fallback, '<non>' the padding token.
244 self.vocab = torchtext.vocab.build_vocab_from_iterator(
246 specials = [ '<unk>', '<non>' ],
247 min_freq = self.min_freq

# Unknown words map to the '<unk>' index instead of raising.
250 self.vocab.set_default_index(self.vocab[ '<unk>' ])

# Encode a list of token lists as one tensor, right-padding each row
# with '<non>' up to the longest row.
252 def tensorize(self, s):
253 a = max(len(x) for x in s)
254 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

# Group tokenized lines of acceptable length into full batches; the
# final partial batch is flushed at the end.
# NOTE(review): the `s = []` accumulator init, the `for l in ds:` loop
# header, `s.append(q)` and `s = []` reset lines are not visible.
256 def yield_batches(self, ds):
259 q = self.tokenizer(l)
260 if len(q) >= self.len_min and len(q) <= self.len_max:
262 if len(s) == self.batch_size:
263 yield self.tensorize(s)
267 yield self.tensorize(s)

268

269 def batches(self, split = 'train'):
270 data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
# Optionally cap the number of lines per epoch.
273 if args.data_size > 0:
274 data_iter = itertools.islice(data_iter, args.data_size)
276 return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

278 def vocabulary_size(self):
279 return len(self.vocab)

# Autoregressively continue a fixed list of English primers and write
# the results to a per-epoch text file.
281 def produce_results(self, n_epoch, model, nb_tokens = 50):
282 file_name = f'result_wiki103_{n_epoch:04d}.txt'

284 with open(file_name, 'w') as outfile:
# NOTE(review): the `for primer in [` line opening this literal list is
# not visible.
286 'the cat is hunting a',
287 'paris is the capital',
288 'cars are convenient',
289 'the difference between men and women is',
290 'the object was blue all over and green all over it was',
291 'cherries are red and lemons are',
292 'cherries are sweet and lemons are',
293 'two plus three equals',

296 t_primer = self.tokenizer(primer)
# NOTE(review): the `t_generated = []` initialization is not visible.
299 for j in range(nb_tokens):
301 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
302 output = model(input)
# Next-token logits at the last position.
303 logits = output[0, -1]
# Sample from the categorical distribution, or (else branch, not
# visible) presumably take the argmax.
304 if args.synthesis_sampling:
305 dist = torch.distributions.categorical.Categorical(logits = logits)
309 t_generated.append(self.vocab.lookup_token(t))
# Stop as soon as the padding token is produced.
310 if t_generated[-1] == '<non>': break

312 s = ' '.join(t_generated)

314 outfile.write(f'<{primer}> {s}\n')

316 log_string(f'wrote {file_name}')
318 ######################################################################
# Pixel-level autoregressive modeling of MNIST: each 28x28 image is a
# sequence of 784 integer pixel values (vocabulary 0..255).
# NOTE(review): interior lines are missing from this view (self.device
# assignment, the `yield batch` line, `return 256`, the pixel-write and
# argmax/sample branches) — flagged inline below.
320 class TaskMNIST(Task):

322 def __init__(self, batch_size, device = torch.device('cpu')):
# NOTE(review): a `self.device = device` assignment is expected here
# (self.device is read in produce_results) but is not visible.
324 self.batch_size = batch_size

326 def batches(self, split = 'train'):
327 assert split in { 'train', 'test' }
328 data_set = torchvision.datasets.MNIST(
329 root = './data', train = (split == 'train'),
# Flatten each image into a 784-long sequence of integer pixel values.
332 data_input = data_set.data.view(-1, 28 * 28).long()
333 if args.data_size >= 0:
334 data_input = data_input[:args.data_size]
# NOTE(review): the `yield batch` line of this loop is not visible.
335 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):

# One class per possible pixel value — presumably `return 256`; the
# return line is not visible here.
338 def vocabulary_size(self):

# Synthesize nb_samples images pixel by pixel and save them as a grid.
341 def produce_results(self, n_epoch, model, nb_samples = 64):
342 results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
343 for input in results.split(self.batch_size):
# Generate position s+1 from the model's prediction at position s.
344 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
345 output = model(input)
346 logits = output[:, s]
# Sample per-pixel, or (else branch, not visible) presumably argmax;
# the write into input[:, s + 1] is also not visible.
347 if args.synthesis_sampling:
348 dist = torch.distributions.categorical.Categorical(logits = logits)

# Invert grayscale so digits render dark-on-light.
354 image_name = f'result_mnist_{n_epoch:04d}.png'
355 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
356 image_name, nrow = 16, pad_value = 0.8)
357 log_string(f'wrote {image_name}')
359 ######################################################################
# Sanity check that the model is causal: measures, via autograd, how
# much each output position depends on each input position. For a
# causal model a[k] should be zero for input positions after k.
# NOTE(review): the `output = model(input)` line and whatever
# inspects/returns `a` are missing from this view; `dim_model` is a free
# name here — presumably args.dim_model, confirm before using.
361 def check_causality(model):
363 input = torch.rand(1, 5, dim_model).requires_grad_()
365 a = torch.zeros(output.size(1), input.size(1))
366 for k in range(output.size(1)):
367 for d in range(output.size(2)):
# Gradient of one scalar output w.r.t. the whole input tensor.
368 g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
# Accumulate squared sensitivity of output k to every input position.
369 a[k] += g.squeeze(0).pow(2).sum(1)
372 ######################################################################
374 log_string(f'device {device}')

# Instantiate the task selected by --data.
376 if args.data == 'wiki103':
377 task = TaskWiki103(batch_size = args.batch_size, device = device)
378 elif args.data == 'mnist':
379 task = TaskMNIST(batch_size = args.batch_size, device = device)
380 elif args.data == 'picoclvr':
381 task = TaskPicoCLVR(batch_size = args.batch_size,
382 height = args.picoclvr_height,
383 width = args.picoclvr_width,
384 many_colors = args.picoclvr_many_colors,
# NOTE(review): the closing `device = device)` line and the `else:`
# guard before this raise are missing from this view.
387 raise ValueError(f'Unknown dataset {args.data}.')

389 vocabulary_size = task.vocabulary_size()

391 log_string(f'vocabulary_size {vocabulary_size}')

393 ##############################

# NOTE(review): the `model = mygpt.MyGPT(` call head, its closing
# paren, and the `model.to(device)` line are missing from this view;
# these are the constructor's keyword arguments.
396 vocabulary_size = vocabulary_size,
397 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
398 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout

# Report the parameter count, also rounded to millions.
403 nb_parameters = sum(p.numel() for p in model.parameters())
404 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
406 ######################################################################
# Build the optimizer selected by --optim ('sgd', 'adam', or 'adamw'),
# all with the same --learning_rate.
# BUG FIX: in the text under review the `else:` guard before the raise
# was missing, so the ValueError fired unconditionally even for a valid
# --optim value; the guard is restored here.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
417 ######################################################################
nb_epochs_finished = 0

# Resume from an existing checkpoint when one is present. A missing
# file simply means a fresh start; any other failure (corrupt file,
# mismatched state dicts) aborts the run.
# BUG FIX / cleanup: the `try:` line and the catch-all handler tail were
# missing from the text under review; they are restored here, and the
# original bare `except:` is narrowed to `except Exception:` so that
# KeyboardInterrupt / SystemExit are not swallowed.
try:
    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    print('Starting from scratch.')

except Exception:
    print('Error when loading the checkpoint.')
    exit(1)
435 ######################################################################
# Main training loop: one pass over the train split, one evaluation
# pass, sample generation, then a checkpoint save — per epoch.
# NOTE(review): several lines are missing from this view (presumably
# `model.train()`, `loss.backward()`, `optimizer.step()`,
# `model.eval()`, and the `checkpoint = {` / `}` dict delimiters) —
# flagged inline below.
437 for k in range(nb_epochs_finished, args.nb_epochs):

441 nb_train_samples, acc_train_loss = 0, 0.0

443 for input in task.batches(split = 'train'):
444 input = input.to(device)
445 output = model(input)
# Next-token prediction: logits at positions :-1 against tokens at 1:.
# cross_entropy wants (N, C, d) logits, hence the transpose.
446 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
# Accumulate the sum of per-sample losses for an exact epoch average.
447 acc_train_loss += loss.item() * input.size(0)
448 nb_train_samples += input.size(0)

450 optimizer.zero_grad()
# NOTE(review): `loss.backward()` and `optimizer.step()` are expected
# after this line but are not visible in this view.

# Evaluation under no_grad; no gradients are needed here.
454 with torch.autograd.no_grad():

458 nb_test_samples, acc_test_loss = 0, 0.0

460 for input in task.batches(split = 'test'):
461 input = input.to(device)
462 output = model(input)
463 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
464 acc_test_loss += loss.item() * input.size(0)
465 nb_test_samples += input.size(0)

# Clamp the mean loss at 100 before exp() to avoid float overflow in
# the reported perplexity.
467 train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
468 test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

470 log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

472 task.produce_results(k, model)

# Checkpoint after every epoch so a crash loses at most one epoch.
# NOTE(review): the `checkpoint = {` opener and `}` closer around these
# entries are not visible in this view.
475 'nb_epochs_finished': k + 1,
476 'model_state': model.state_dict(),
477 'optimizer_state': optimizer.state_dict()

480 torch.save(checkpoint, args.checkpoint_name)
482 ######################################################################