# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

# The model itself lives in the companion mygpt module; the class name used
# below (mygpt.MyGPT) is assumed from that module.
import mygpt
######################################################################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action='store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

parser.add_argument('--picoclvr_many_colors',
                    action='store_true', default = False)
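
# A couple of illustrative invocations (the script filename main.py is an
# assumption, not confirmed by the source):
#
#   python main.py --data mnist --nb_epochs 10
#   python main.py --data picoclvr --picoclvr_many_colors --batch_size 50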
######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)

######################################################################
def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
######################################################################

class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        pass

######################################################################
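
# A note on the Task contract implied by the training loop at the bottom of
# this file: batches() yields LongTensor batches of token indices of shape
# (batch, length), vocabulary_size() returns the number of distinct token
# values, and produce_results() writes qualitative samples for a given epoch.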
# picoclvr is the companion module generating the synthetic scene descriptions
import picoclvr

class TaskPicoCLVR(Task):

    def __init__(self, batch_size,
                 height = 6, width = 8, many_colors = False,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        descr = picoclvr.generate(
            nb,
            height = height, width = width,
            many_colors = many_colors
        )

        # self.test_descr = descr[:nb // 5]
        # self.train_descr = descr[nb // 5:]

        # Pad all descriptions to the same length with '<unk>'
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        tokens = set()
        for s in descr:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        data_input = torch.tensor(t, device = self.device)
        self.test_input = data_input[:nb // 5]
        self.train_input = data_input[nb // 5:]
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch

    def vocabulary_size(self):
        return len(self.token2id)
    def generate(self, primer, model, nb_tokens):
        t_primer = primer.strip().split(' ')
        t_generated = [ ]

        for j in range(nb_tokens):
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            output = model(input)
            logits = output[0, -1]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t = dist.sample()
            else:
                t = logits.argmax()
            t_generated.append(self.id2token[t.item()])

        return ' '.join(t_primer + t_generated)
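
    # Note that generate() re-encodes the primer plus everything generated so
    # far and re-runs the full model at every step, so synthesis is quadratic
    # in the number of generated tokens; acceptable for these short sequences.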
    def produce_results(self, n_epoch, model, nb_tokens = 50):
        descr = [ ]
        # The value 8 is assumed; it only sets how many samples are drawn per
        # primer, one primer per row in the saved image
        nb_per_primer = 8

        for primer in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:

            for k in range(nb_per_primer):
                descr.append(self.generate(primer, model, nb_tokens))

        img = [ picoclvr.descr2img(d) for d in descr ]
        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(img / 255.,
                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
        log_string(f'wrote {file_name}')

        nb_missing = sum( [ x[2] for x in picoclvr.nb_missing_properties(descr) ] )
        log_string(f'nb_missing {nb_missing / len(descr):.02f}')
######################################################################

class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
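
    # For instance (with hypothetical tokens), tensorize([ ['a', 'b'], ['c'] ])
    # pads the second sentence with one '<non>' so both rows have length 2
    # before being mapped to vocabulary indices.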
    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        if len(s) > 0:
            yield self.tensorize(s)
    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)
    def produce_results(self, n_epoch, model, nb_tokens = 50):
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:

                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):

                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t))
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
######################################################################

class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = args.download  # assumed to honor the --download flag
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return 256  # one token per possible grayscale pixel value
    def produce_results(self, n_epoch, model, nb_samples = 64):
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t = dist.sample()
                else:
                    t = logits.argmax(1)
                input[:, s + 1] = t

        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
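
    # Images are synthesized pixel by pixel in raster order: split() returns
    # views into results, so writing into input fills results in place, and
    # 1 - x / 255. simply inverts the grayscale for display.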
######################################################################

def check_causality(model):
    # Debug helper: if the model is causal, the output at position k must not
    # depend on input positions > k, hence a should be lower-triangular. This
    # assumes the model accepts already-embedded float input (e.g. the stack
    # of blocks without the token embedding).
    input = torch.rand(1, 5, args.dim_model).requires_grad_()
    output = model(input)
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
    print(a)
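
# This helper is not invoked in the pipeline below; a typical use would be to
# call it on the attention trunk right after construction and check that the
# printed matrix has (numerically) zero entries above the diagonal.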
######################################################################

log_string(f'device {device}')

if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
##############################

# The class name MyGPT is assumed from the companion mygpt module
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
######################################################################

if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')

######################################################################
nb_epochs_finished = 0

try:
    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    print('Starting from scratch.')

except:
    print('Error when loading the checkpoint.')
    exit(1)

######################################################################
for k in range(nb_epochs_finished, args.nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Predict token t+1 from tokens up to t: drop the last output column
        # and the first input column
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0
        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
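
        # Perplexity is exp of the average next-token cross-entropy (in nats);
        # the min(100, .) clamp only keeps math.exp from overflowing when a
        # run diverges.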
        task.produce_results(k, model)
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
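
    # Saving after every epoch, together with the nb_epochs_finished counter
    # restored above, lets an interrupted run resume from the last completed
    # epoch.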
######################################################################