3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
######################################################################

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

######################################################################
# Command-line configuration. Flags, types and defaults reproduce the
# original settings exactly; only layout and comments differ.
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Logging / reproducibility / dataset download
parser.add_argument('--log_filename', type = str, default = 'train.log')
parser.add_argument('--download', action='store_true', default = False)
parser.add_argument('--seed', type = int, default = 0)

# Training schedule
parser.add_argument('--nb_epochs', type = int, default = 100)
parser.add_argument('--batch_size', type = int, default = 25)

# Task selection ('wiki103', 'mnist' or 'picoclvr') and optional
# truncation of the dataset (-1 = use everything)
parser.add_argument('--data', type = str, default = 'wiki103')
parser.add_argument('--data_size', type = int, default = -1)

# Optimization
parser.add_argument('--optim', type = str, default = 'adam')
parser.add_argument('--learning_rate', type = float, default = 1e-4)

# Model geometry
parser.add_argument('--dim_model', type = int, default = 512)
parser.add_argument('--dim_keys', type = int, default = 64)
parser.add_argument('--dim_hidden', type = int, default = 2048)
parser.add_argument('--nb_heads', type = int, default = 8)
parser.add_argument('--nb_blocks', type = int, default = 12)
parser.add_argument('--dropout', type = float, default = 0.1)

# NOTE(review): with action='store_true' and default=True this flag is
# always True and cannot be switched off from the command line; kept
# as-is to preserve behavior.
parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth')

##############################

# picoclvr-specific options
parser.add_argument('--picoclvr_many_colors',
                    action='store_true', default = False)
parser.add_argument('--picoclvr_height', type = int, default = 12)
parser.add_argument('--picoclvr_width', type = int, default = 16)

######################################################################

args = parser.parse_args()

# Kept open for the whole run; the logging helper below appends to it.
log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
96 ######################################################################
# NOTE(review): the lines below are the interior of a timestamped logging
# helper (presumably `def log_string(s):`) plus the loop that dumps every
# parsed argument; the def/loop headers and the stdout echo / flush lines
# are missing from this extraction — confirm against the original file.
99 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
# Prepend a wall-clock timestamp and append the message to the file
# opened above from --log_filename.
101 if log_file is not None:
102 log_file.write(t + s + '\n')
# Logs one `args.<name> <value>` line per parsed command-line argument.
109 log_string(f'args.{n} {getattr(args, n)}')
111 ######################################################################
# NOTE(review): abstract task interface implemented by the classes below.
# The enclosing `class Task:` header and the method bodies (presumably
# `pass`) are missing from this extraction.
114 def batches(self, split = 'train'):
# Yields mini-batches for the 'train' or 'test' split.
117 def vocabulary_size(self):
# Returns the number of distinct tokens the model must embed/predict.
120 def produce_results(self, n_epoch, model, nb_tokens = 50):
# Generates and records sample outputs for epoch `n_epoch`.
123 ######################################################################
# Task wrapping the picoclvr synthetic scene-description data: scenes are
# generated as space-separated token sequences, right-padded with '<unk>'
# to equal length, and encoded as integer tensors.
# NOTE(review): many interior lines of this class are missing from this
# extraction (the embedded original line numbers jump); the gaps are
# flagged below rather than guessed at.
127 class TaskPicoCLVR(Task):
129 def __init__(self, batch_size,
130 height, width, many_colors = False,
131 device = torch.device('cpu')):
# Generate `nb` scene descriptions with the external picoclvr module.
# NOTE(review): the `nb` argument line and closing bracket of this call
# (orig. ~135, 138-139) are missing.
133 def generate_descr(nb):
134 descr = picoclvr.generate(
136 height = self.height, width = self.width,
137 many_colors = many_colors
# Tokenize on spaces and right-pad every description to the longest one.
140 descr = [ s.strip().split(' ') for s in descr ]
141 l = max([ len(s) for s in descr ])
142 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
# NOTE(review): the assignments of self.height / self.width / self.device
# (orig. ~144-147) are missing but are clearly relied upon throughout.
148 self.batch_size = batch_size
# 250000 samples unless --data_size overrides; 80% / 20% train / test.
150 nb = args.data_size if args.data_size > 0 else 250000
152 self.train_descr = generate_descr((nb * 4) // 5)
153 self.test_descr = generate_descr((nb * 1) // 5)
# Build the token <-> id vocabulary over both splits.
# NOTE(review): the initialization of `tokens` (presumably a set) and the
# inner loop over descriptions (orig. ~155, 157) are missing.
156 for d in [ self.train_descr, self.test_descr ]:
158 for t in s: tokens.add(t)
159 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
160 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
# Encode both splits as integer tensors (all rows have equal length).
162 t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ]
163 self.train_input = torch.tensor(t, device = self.device)
164 t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ]
165 self.test_input = torch.tensor(t, device = self.device)
# Yield mini-batches of the requested split with a progress bar.
# NOTE(review): the `if split == ...` branch lines and the `yield batch`
# bodies (orig. ~169, 171-172, 174) are missing.
167 def batches(self, split = 'train'):
168 assert split in { 'train', 'test' }
170 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
173 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
# Number of distinct tokens seen when building the vocabulary.
176 def vocabulary_size(self):
177 return len(self.token2id)
# Autoregressively extend `primer` by nb_tokens tokens using `model`,
# re-encoding the full prefix at every step.
# NOTE(review): the initialization of t_generated (orig. ~181) and the
# greedy (argmax) alternative to sampling (orig. ~190-192) are missing.
179 def generate(self, primer, model, nb_tokens):
180 t_primer = primer.strip().split(' ')
183 for j in range(nb_tokens):
184 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
185 input = torch.tensor(t, device = self.device)
186 output = model(input)
# Logits of the next-token distribution at the last position.
187 logits = output[0, -1]
188 if args.synthesis_sampling:
189 dist = torch.distributions.categorical.Categorical(logits = logits)
193 t_generated.append(self.id2token[t.item()])
195 return ' '.join(t_primer + t_generated)
# Render generations for a fixed set of primers into a PNG grid and log
# the average number of requested-but-missing scene properties.
197 def produce_results(self, n_epoch, model, nb_tokens = None):
198 if nb_tokens is None:
# One token per cell plus (presumably) separator/image markers.
199 nb_tokens = self.height * self.width + 3
# NOTE(review): the list header for these primers (orig. ~201-203) is
# missing.
204 'red above green <sep> green top <sep> blue right of red <img>',
205 'there is red <sep> there is yellow <sep> there is blue <img>',
206 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
207 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
# NOTE(review): nb_per_primer, the `descr` accumulator and the outer loop
# over primers (orig. ~208-209) are missing.
210 for k in range(nb_per_primer):
211 descr.append(self.generate(primer, model, nb_tokens))
213 img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
214 img = torch.cat(img, 0)
215 file_name = f'result_picoclvr_{n_epoch:04d}.png'
216 torchvision.utils.save_image(
218 file_name, nrow = nb_per_primer, pad_value = 0.8
220 log_string(f'wrote {file_name}')
# Count properties not satisfied by the generated scenes.
# NOTE(review): the surrounding sum(...) / nb_missing assignment lines
# (orig. ~221-222, 224, 226-228) are missing.
223 x[2] for x in picoclvr.nb_missing_properties(
225 height = self.height, width = self.width
229 log_string(f'nb_missing {nb_missing / len(descr):.02f}')
231 ######################################################################
# Task wrapping WikiText-103 via torchtext: builds a frequency-filtered
# vocabulary on the training split and yields batches of token-id rows
# right-padded with '<non>'.
# NOTE(review): interior lines are missing from this extraction (the
# embedded original line numbers jump); the gaps are flagged below.
233 class TaskWiki103(Task):
235 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
236 device = torch.device('cpu')):
238 self.batch_size = batch_size
239 self.len_min = len_min
240 self.len_max = len_max
241 self.min_freq = min_freq
# NOTE(review): `self.device = device` (orig. ~242-243) is missing but
# self.device is used in produce_results below.
244 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
245 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
# Optionally truncate the corpus for quick runs.
248 if args.data_size > 0:
249 train_iter = itertools.islice(train_iter, args.data_size)
# NOTE(review): the generator header feeding the vocabulary builder
# (presumably `def yield_tokens():`, orig. ~251) is missing.
252 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
253 yield self.tokenizer(l)
# Keep only tokens seen at least min_freq times; '<unk>' is the default
# for out-of-vocabulary tokens and '<non>' is the padding token.
# NOTE(review): the first argument line of this call (orig. ~256) is
# missing.
255 self.vocab = torchtext.vocab.build_vocab_from_iterator(
257 specials = [ '<unk>', '<non>' ],
258 min_freq = self.min_freq
261 self.vocab.set_default_index(self.vocab[ '<unk>' ])
# Convert a list of token lists into one tensor, right-padding each row
# with '<non>' to the length of the longest row.
263 def tensorize(self, s):
264 a = max(len(x) for x in s)
265 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
# Group lines of `ds` whose tokenized length lies in [len_min, len_max]
# into batches of batch_size, plus a final partial batch.
# NOTE(review): the accumulator init / append / reset and the final
# non-empty check (orig. ~268-269, 272, 275-277) are missing.
267 def yield_batches(self, ds):
270 q = self.tokenizer(l)
271 if len(q) >= self.len_min and len(q) <= self.len_max:
273 if len(s) == self.batch_size:
274 yield self.tensorize(s)
278 yield self.tensorize(s)
# Stream batches for either split, with the same --data_size truncation
# as the vocabulary pass.
280 def batches(self, split = 'train'):
281 data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
284 if args.data_size > 0:
285 data_iter = itertools.islice(data_iter, args.data_size)
287 return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
289 def vocabulary_size(self):
290 return len(self.vocab)
# Autoregressively complete a fixed list of English primers and write
# the results to a per-epoch text file.
292 def produce_results(self, n_epoch, model, nb_tokens = 50):
293 file_name = f'result_wiki103_{n_epoch:04d}.txt'
295 with open(file_name, 'w') as outfile:
# NOTE(review): the loop/list header for these primers (orig. ~296) is
# missing.
297 'the cat is hunting a',
298 'paris is the capital',
299 'cars are convenient',
300 'the difference between men and women is',
301 'the object was blue all over and green all over it was',
302 'cherries are red and lemons are',
303 'cherries are sweet and lemons are',
304 'two plus three equals',
307 t_primer = self.tokenizer(primer)
# NOTE(review): the t_generated initialization (orig. ~308) and the
# greedy (argmax) alternative to sampling (orig. ~317-319) are missing.
310 for j in range(nb_tokens):
312 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
313 output = model(input)
# Logits of the next-token distribution at the last position.
314 logits = output[0, -1]
315 if args.synthesis_sampling:
316 dist = torch.distributions.categorical.Categorical(logits = logits)
320 t_generated.append(self.vocab.lookup_token(t))
# Stop as soon as the model emits the padding token.
321 if t_generated[-1] == '<non>': break
323 s = ' '.join(t_generated)
325 outfile.write(f'<{primer}> {s}\n')
327 log_string(f'wrote {file_name}')
329 ######################################################################
# Task modeling MNIST digits as autoregressive sequences of 28*28 pixel
# values.
# NOTE(review): interior lines are missing from this extraction; the
# gaps are flagged below.
331 class TaskMNIST(Task):
333 def __init__(self, batch_size, device = torch.device('cpu')):
# NOTE(review): `self.device = device` (orig. ~334) is missing but
# self.device is used in produce_results below.
335 self.batch_size = batch_size
# Yield batches of flattened images as int64 pixel sequences.
337 def batches(self, split = 'train'):
338 assert split in { 'train', 'test' }
339 data_set = torchvision.datasets.MNIST(
340 root = './data', train = (split == 'train'),
# NOTE(review): the remaining MNIST(...) arguments (presumably
# download = args.download, orig. ~341-342) are missing.
343 data_input = data_set.data.view(-1, 28 * 28).long()
344 if args.data_size >= 0:
345 data_input = data_input[:args.data_size]
346 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
# NOTE(review): the `yield batch` body (orig. ~347) is missing.
349 def vocabulary_size(self):
# NOTE(review): the return value (presumably 256, one token per 8-bit
# pixel value, orig. ~350) is missing — confirm against the original.
# Generate nb_samples images pixel by pixel and save them as a PNG grid.
352 def produce_results(self, n_epoch, model, nb_samples = 64):
353 results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
354 for input in results.split(self.batch_size):
355 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
356 output = model(input)
357 logits = output[:, s]
358 if args.synthesis_sampling:
359 dist = torch.distributions.categorical.Categorical(logits = logits)
# NOTE(review): the sampled/argmax token assignment into the next
# position (orig. ~360-363) is missing.
365 image_name = f'result_mnist_{n_epoch:04d}.png'
# Map pixel values from [0, 255] to [0, 1] and invert so digits render
# dark on a light background.
366 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
367 image_name, nrow = 16, pad_value = 0.8)
368 log_string(f'wrote {image_name}')
370 ######################################################################
# Gradient-based influence check: a[k, i] accumulates the squared
# gradient of output position k w.r.t. input position i — presumably
# used to verify that the model is causal (no influence from future
# positions); confirm against the original file.
# NOTE(review): `dim_model` is a bare global here (the script elsewhere
# uses args.dim_model); the forward pass (presumably `output =
# model(input)`, orig. ~375) and the final check / return (orig.
# ~381-382) are missing from this extraction.
372 def check_causality(model):
374 input = torch.rand(1, 5, dim_model).requires_grad_()
376 a = torch.zeros(output.size(1), input.size(1))
377 for k in range(output.size(1)):
378 for d in range(output.size(2)):
# Gradient of one scalar output w.r.t. the whole input; accumulate its
# squared magnitude per input position.
379 g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
380 a[k] += g.squeeze(0).pow(2).sum(1)
383 ######################################################################
385 log_string(f'device {device}')
# Instantiate the task selected with --data; each task provides batches,
# a vocabulary size and sample generation.
387 if args.data == 'wiki103':
388 task = TaskWiki103(batch_size = args.batch_size, device = device)
389 elif args.data == 'mnist':
390 task = TaskMNIST(batch_size = args.batch_size, device = device)
391 elif args.data == 'picoclvr':
392 task = TaskPicoCLVR(batch_size = args.batch_size,
393 height = args.picoclvr_height,
394 width = args.picoclvr_width,
395 many_colors = args.picoclvr_many_colors,
# NOTE(review): the closing device argument of the call above and the
# `else:` introducing this raise (orig. ~396-397) are missing from this
# extraction.
398 raise ValueError(f'Unknown dataset {args.data}.')
400 vocabulary_size = task.vocabulary_size()
402 log_string(f'vocabulary_size {vocabulary_size}')
404 ##############################
# NOTE(review): the model constructor header (orig. ~405-406, presumably
# the project's GPT class assigned to `model`) is missing; the lines
# below are its keyword arguments.
407 vocabulary_size = vocabulary_size,
408 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
409 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
414 nb_parameters = sum(p.numel() for p in model.parameters())
415 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
417 ######################################################################
# Optimizer selection from --optim.
419 if args.optim == 'sgd':
420 optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
421 elif args.optim == 'adam':
422 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
423 elif args.optim == 'adamw':
424 optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
# NOTE(review): the `else:` introducing this raise (orig. ~425) is
# missing.
426 raise ValueError(f'Unknown optimizer {args.optim}.')
428 ######################################################################
# Resume from a checkpoint when one exists; otherwise start fresh.
430 nb_epochs_finished = 0
# NOTE(review): the `try:` opening this block (orig. ~431-432) is
# missing.
433 checkpoint = torch.load(args.checkpoint_name, map_location = device)
434 nb_epochs_finished = checkpoint['nb_epochs_finished']
435 model.load_state_dict(checkpoint['model_state'])
436 optimizer.load_state_dict(checkpoint['optimizer_state'])
437 print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
439 except FileNotFoundError:
440 print('Starting from scratch.')
# NOTE(review): the handler introducing this line (orig. ~441-442) is
# missing — if it is a bare `except:`, it silently swallows unrelated
# errors and is worth narrowing in the original file.
443 print('Error when loading the checkpoint.')
446 ######################################################################
# Main loop: one training pass, one evaluation pass and one sampling
# round per epoch, followed by a checkpoint save.
448 for k in range(nb_epochs_finished, args.nb_epochs):
452 nb_train_samples, acc_train_loss = 0, 0.0
454 for input in task.batches(split = 'train'):
455 input = input.to(device)
456 output = model(input)
# Next-token prediction: the logits at position t are scored against the
# token at position t+1 (cross_entropy expects class dim second, hence
# the transpose).
457 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
458 acc_train_loss += loss.item() * input.size(0)
459 nb_train_samples += input.size(0)
461 optimizer.zero_grad()
# NOTE(review): `loss.backward()` and `optimizer.step()` (orig.
# ~462-463) are missing from this extraction — without them no training
# happens; confirm they are present in the original file.
465 with torch.autograd.no_grad():
# NOTE(review): `model.eval()` (orig. ~466-467) appears to be missing.
469 nb_test_samples, acc_test_loss = 0, 0.0
471 for input in task.batches(split = 'test'):
472 input = input.to(device)
473 output = model(input)
474 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
475 acc_test_loss += loss.item() * input.size(0)
476 nb_test_samples += input.size(0)
# The min(100, ...) clamp keeps math.exp from overflowing on diverged
# runs.
478 train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
479 test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
481 log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
483 task.produce_results(k, model)
# NOTE(review): the `checkpoint = {` header of this dict literal (orig.
# ~485) is missing.
486 'nb_epochs_finished': k + 1,
487 'model_state': model.state_dict(),
488 'optimizer_state': optimizer.state_dict()
491 torch.save(checkpoint, args.checkpoint_name)
493 ######################################################################