3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
######################################################################

# Pick the compute device once at startup: use the GPU when one is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

######################################################################
# Command-line interface.  All hyper-parameters of the run are exposed as
# flags; defaults reproduce the standard configuration.
parser = argparse.ArgumentParser(description='My own GPT.')

# Run control and logging
parser.add_argument('--log_filename', type=str, default='train.log')
parser.add_argument('--download', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--nb_epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=25)

# Dataset selection; data_size <= 0 means "use the whole corpus"
parser.add_argument('--data', type=str, default='wiki103')
parser.add_argument('--data_size', type=int, default=-1)

# Optimization
parser.add_argument('--optim', type=str, default='adam')
parser.add_argument('--learning_rate', type=float, default=1e-4)

# Model geometry
parser.add_argument('--dim_model', type=int, default=512)
parser.add_argument('--dim_keys', type=int, default=64)
parser.add_argument('--dim_hidden', type=int, default=2048)
parser.add_argument('--nb_heads', type=int, default=8)
parser.add_argument('--nb_blocks', type=int, default=12)
parser.add_argument('--dropout', type=float, default=0.1)

# NOTE(review): store_true combined with default=True makes this flag a
# no-op — it is always True and cannot be disabled from the command line.
parser.add_argument('--synthesis_sampling', action='store_true', default=True)

parser.add_argument('--checkpoint_name', type=str, default='checkpoint.pth')

##############################
# picoclvr-specific options

parser.add_argument('--picoclvr_many_colors', action='store_true', default=False)
parser.add_argument('--picoclvr_height', type=int, default=12)
parser.add_argument('--picoclvr_width', type=int, default=16)
######################################################################

args = parser.parse_args()

# Kept open for the whole run and never explicitly closed (the OS reclaims
# it on exit).  NOTE(review): a context manager or atexit hook would be
# cleaner — confirm intent before changing.
log_file = open(args.log_filename, 'w')

# Seed the PRNG so runs are reproducible for a given --seed.
torch.manual_seed(args.seed)

######################################################################

# NOTE(review): the `def log_string(s):` header and part of its body
# (stdout echo / flush) appear elided from this chunk — the three lines
# below are the visible remainder: a timestamped write to the log file.
t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
if log_file is not None:
    log_file.write(t + s + '\n')

# NOTE(review): presumably inside a `for n in vars(args):` loop whose
# header is elided — logs every parsed argument at startup.
log_string(f'args.{n} {getattr(args, n)}')
######################################################################

# NOTE(review): the `class Task:` header and the method bodies (likely
# bare `pass` statements) appear elided from this chunk.  These three
# signatures are the interface every concrete task below implements:
# batches() yields minibatch tensors for a split, vocabulary_size()
# reports the token count, produce_results() writes samples to disk.
def batches(self, split = 'train'):

def vocabulary_size(self):

def produce_results(self, n_epoch, model, nb_tokens = 50):

######################################################################
class TaskPicoCLVR(Task):

    # Task wrapping the picoclvr synthetic scene-description corpus:
    # generates descriptions, builds a token vocabulary, and keeps the
    # first fifth of the encoded corpus as test data, the rest as train.
    #
    # NOTE(review): several source lines of this class appear elided from
    # this chunk (height/width attribute assignments, the closing
    # parenthesis of picoclvr.generate, the construction of the `tokens`
    # set, the yield statements of batches, and the sampling/argmax lines
    # of generate).  Comments below are hedged accordingly.

    def __init__(self, batch_size,
                 height, width, many_colors = False,
                 device = torch.device('cpu')):

        self.batch_size = batch_size

        # args.data_size <= 0 means "use the default corpus size"
        nb = args.data_size if args.data_size > 0 else 250000

        # presumably also receives nb and returns a list of description
        # strings — the closing of this call is not visible here
        descr = picoclvr.generate(
            height = self.height, width = self.width,
            many_colors = many_colors

        # self.test_descr = descr[:nb // 5]
        # self.train_descr = descr[nb // 5:]

        # Tokenize, then right-pad every description with '<unk>' so all
        # rows share the maximal length l.
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        # Build the token <-> id maps; `tokens` is presumably a set filled
        # from every description (its initialization is not visible here).
        for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Encode the whole corpus as one integer tensor on the target
        # device and split it 1/5 test, 4/5 train.
        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        data_input = torch.tensor(t, device = self.device)
        self.test_input = data_input[:nb // 5]
        self.train_input = data_input[nb // 5:]

    def batches(self, split = 'train'):
        # Iterates minibatches of the requested split.  NOTE(review): the
        # if/else dispatch on split and the `yield batch` bodies appear
        # elided here.
        assert split in { 'train', 'test' }
        for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
        for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):

    def vocabulary_size(self):
        # One id per distinct token collected in __init__.
        return len(self.token2id)

    def generate(self, primer, model, nb_tokens):
        # Autoregressive synthesis: at every step the full prefix
        # (primer + generated tokens) is re-encoded and fed to the model.
        # NOTE(review): the initialization of t_generated and the actual
        # sampling / argmax producing `t` from the logits appear elided.
        t_primer = primer.strip().split(' ')

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            output = model(input)
            logits = output[0, -1]   # next-token scores after the prefix
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
            t_generated.append(self.id2token[t.item()])

        return ' '.join(t_primer + t_generated)

    def produce_results(self, n_epoch, model, nb_tokens = None):
        # Generates scenes from a fixed list of primer prompts, renders
        # them into one PNG grid, and logs the rate of unsatisfied
        # properties.
        if nb_tokens is None:
            # one token per grid cell, plus a small margin
            nb_tokens = self.height * self.width + 3

        # NOTE(review): the loop header over these primer strings (and the
        # initialization of descr / nb_per_primer) appears elided.
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',

        for k in range(nb_per_primer):
            descr.append(self.generate(primer, model, nb_tokens))

        # Render every generated description and save one grid image,
        # nb_per_primer samples per row.
        img = [ picoclvr.descr2img(d, height = self.height, width = self.width) for d in descr ]
        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            file_name, nrow = nb_per_primer, pad_value = 0.8
        log_string(f'wrote {file_name}')

        # Count requested properties the generated scenes fail to satisfy.
        # NOTE(review): the surrounding sum(...) / nb_missing assignment
        # appears elided here.
            x[2] for x in picoclvr.nb_missing_properties(
                height = self.height, width = self.width
        log_string(f'nb_missing {nb_missing / len(descr):.02f}')

######################################################################
class TaskWiki103(Task):

    # Task wrapping torchtext's WikiText-103: basic_english tokenizer, a
    # frequency-thresholded vocabulary with '<unk>'/'<non>' specials, and
    # line-wise batches of '<non>'-padded token-id tensors.
    #
    # NOTE(review): some source lines appear elided from this chunk (the
    # local generator header feeding build_vocab_from_iterator, that
    # call's closing parenthesis, and parts of yield_batches and
    # produce_results).  Comments are hedged accordingly.

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min      # keep lines with at least this many tokens
        self.len_max = len_max      # ... and at most this many
        self.min_freq = min_freq    # vocabulary frequency threshold

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Optionally truncate the corpus for quick experiments
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        # presumably the body of a local token-yielding generator whose
        # `def` line is not visible here
        for l in tqdm.tqdm(train_iter, desc = 'vocab'):
            yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq

        # Out-of-vocabulary words map to '<unk>'
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        # Pads every token list in s with '<non>' up to the longest one and
        # returns the id tensor, one row per sentence.
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        # Accumulates tokenized lines of acceptable length into groups of
        # batch_size and yields them tensorized.  NOTE(review): the
        # accumulator handling (s = [], the loop over ds, s.append(q), and
        # the final non-empty guard) appears elided here.
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
        yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Optionally truncate, mirroring __init__
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Continues a fixed list of English primers with the model and
        # writes the completions to a per-epoch text file.  NOTE(review):
        # the loop header over the primer strings, the initialization of
        # t_generated, and the sampling/argmax producing `t` appear elided.
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',

            t_primer = self.tokenizer(primer)

            for j in range(nb_tokens):
                input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                output = model(input)
                logits = output[0, -1]   # next-token scores after the prefix
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                t_generated.append(self.vocab.lookup_token(t))
                if t_generated[-1] == '<non>': break   # padding token ends the sample

            s = ' '.join(t_generated)

            outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')

######################################################################
class TaskMNIST(Task):

    # Treats each MNIST digit as a flat sequence of 784 pixel intensities
    # (integer values, 0-255 range) for autoregressive modeling.
    #
    # NOTE(review): some lines appear elided from this chunk (the device
    # assignment in __init__, the download kwarg / closing of the MNIST
    # constructor, the `yield batch` in batches, the body of
    # vocabulary_size, and the sampling lines of produce_results).

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
        # Flatten 28x28 images into rows of 784 integer pixel values
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):

    def vocabulary_size(self):
        # NOTE(review): the return statement is elided here — presumably
        # 256, one token per pixel intensity; confirm against the original.

    def produce_results(self, n_epoch, model, nb_samples = 64):
        # Synthesizes nb_samples images pixel by pixel and saves them as a
        # 16-wide grid, inverted so digits appear dark on light.
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]   # scores for the next pixel of every sample
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')

######################################################################
def check_causality(model):
    # Sanity check that the model is causal: for every output position k,
    # accumulate the squared gradient of each output dimension w.r.t. the
    # input positions; influence from future positions would reveal a
    # non-causal model.
    # NOTE(review): the `output = model(input)` line and the final
    # comparison/return appear elided from this chunk; `dim_model` is
    # presumably a module-level value here — confirm.
    input = torch.rand(1, 5, dim_model).requires_grad_()
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            # a[k][j] = total squared influence of input position j on output k
            a[k] += g.squeeze(0).pow(2).sum(1)

######################################################################
log_string(f'device {device}')

# Instantiate the task named by --data; every task provides batches(),
# vocabulary_size() and produce_results().
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        many_colors = args.picoclvr_many_colors,
# NOTE(review): the device kwarg / closing parenthesis of the call above
# and the `else:` introducing the line below appear elided from this chunk.
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# NOTE(review): the `model = ...(` constructor header and its closing
# appear elided — the lines below are the model's keyword arguments,
# taken directly from the command-line flags.
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout

# Report the model size once at startup
nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')

######################################################################
# Build the optimizer selected by --optim; all variants share the same
# learning rate flag.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
# NOTE(review): the `else:` introducing the line below appears elided
# from this chunk.
    raise ValueError(f'Unknown optimizer {args.optim}.')

######################################################################
nb_epochs_finished = 0

# Resume from a previous checkpoint when one exists; otherwise start
# from scratch.  NOTE(review): the `try:` opening this block and the
# header of the final catch-all handler appear elided from this chunk.
    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    print('Starting from scratch.')

    # presumably inside a broader except handler whose header is elided
    print('Error when loading the checkpoint.')

######################################################################
# Main loop: one training pass and one evaluation pass per epoch, then
# sample generation and checkpointing.  NOTE(review): several lines
# appear elided from this chunk (model.train()/model.eval() toggles,
# loss.backward(), optimizer.step(), and the `checkpoint = {` dict
# literal header near the end).
for k in range(nb_epochs_finished, args.nb_epochs):

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token objective: logits at position t are scored against
        # the token at position t + 1 (hence the shift and [:, :-1]).
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        # Accumulate the un-averaged loss so the final mean is per-sample
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()

    # NOTE(review): torch.autograd.no_grad is the legacy spelling of
    # torch.no_grad (behaviorally identical) — evaluation needs no graph.
    with torch.autograd.no_grad():

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

    # Clamp the exponent so a diverging loss cannot overflow exp()
    train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
    test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

    log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

    task.produce_results(k, model)

    # Everything needed to resume after epoch k (dict header elided here)
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()

    torch.save(checkpoint, args.checkpoint_name)

######################################################################