3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
# NOTE(review): this listing is line-numbered and has gaps — some statements
# (and all blank lines) are missing from view. Code lines are kept verbatim.
16 ######################################################################
# Use the GPU when available, otherwise fall back to the CPU.
18 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 ######################################################################
# Command-line interface: logging, data selection, optimization and the
# transformer geometry (model/key/hidden dims, heads, blocks, dropout).
22 parser = argparse.ArgumentParser(description = 'My own GPT.')
24 parser.add_argument('--log_filename',
25 type = str, default = 'train.log')
27 parser.add_argument('--download',
28 action='store_true', default = False)
30 parser.add_argument('--seed',
31 type = int, default = 0)
33 parser.add_argument('--nb_epochs',
34 type = int, default = 100)
36 parser.add_argument('--batch_size',
37 type = int, default = 25)
# Which task to train on: 'wiki103', 'mnist' or 'picoclvr' (see dispatch below).
39 parser.add_argument('--data',
40 type = str, default = 'wiki103')
# -1 means "use the full data set"; a positive value truncates it.
42 parser.add_argument('--data_size',
43 type = int, default = -1)
45 parser.add_argument('--optim',
46 type = str, default = 'adam')
48 parser.add_argument('--learning_rate',
49 type = float, default = 1e-4)
51 parser.add_argument('--dim_model',
52 type = int, default = 512)
54 parser.add_argument('--dim_keys',
55 type = int, default = 64)
57 parser.add_argument('--dim_hidden',
58 type = int, default = 2048)
60 parser.add_argument('--nb_heads',
61 type = int, default = 8)
63 parser.add_argument('--nb_blocks',
64 type = int, default = 12)
66 parser.add_argument('--dropout',
67 type = float, default = 0.1)
# NOTE(review): store_true with default=True means this flag can never be
# turned off from the command line — sampling is always on as written.
69 parser.add_argument('--synthesis_sampling',
70 action='store_true', default = True)
72 parser.add_argument('--checkpoint_name',
73 type = str, default = 'checkpoint.pth')
75 parser.add_argument('--picoclvr_many_colors',
76 action='store_true', default = False)
78 ######################################################################
80 args = parser.parse_args()
# Opened in 'w' mode: a new run truncates the previous run's log.
82 log_file = open(args.log_filename, 'w')
# Seed the torch RNG for reproducibility (default seed 0).
85 torch.manual_seed(args.seed)
87 ######################################################################
# Body of log_string(s) — its `def` header falls in a gap of this listing.
# Prefixes a timestamp and mirrors the message to the log file.
90 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
92 if log_file is not None:
93 log_file.write(t + s + '\n')
# Dumps every parsed argument at startup; the `for n in vars(args):` loop
# header presumably sits in the gap just above — TODO confirm.
100 log_string(f'args.{n} {getattr(args, n)}')
102 ######################################################################
# Abstract interface of the Task base class — the `class Task:` header and
# the stub bodies fall in gaps of this listing. Each concrete task below
# (TaskPicoCLVR, TaskWiki103, TaskMNIST) overrides these three hooks.
# Yields batches of token-id tensors for the given split ('train' or 'test').
105 def batches(self, split = 'train'):
# Returns the number of distinct tokens the task uses.
108 def vocabulary_size(self):
# Samples nb_tokens continuations from `model` and writes the epoch's
# results (image or text file) to disk.
111 def produce_results(self, n_epoch, model, nb_tokens = 50):
114 ######################################################################
# Synthetic picture-description task: scene descriptions are generated by a
# `picoclvr` helper module (imported elsewhere in the file), tokenized on
# spaces, padded with '<unk>' to a common length, and split 80/20 into
# train/test tensors of token ids.
118 class TaskPicoCLVR(Task):
120 def __init__(self, batch_size,
121 height = 6, width = 8, many_colors = False,
122 device = torch.device('cpu')):
# NOTE(review): self.device is read below (L145) but its assignment falls in
# a gap of this listing — presumably `self.device = device`; confirm.
124 self.batch_size = batch_size
# Number of generated scenes; --data_size <= 0 means the default 250000.
126 nb = args.data_size if args.data_size > 0 else 250000
128 descr = picoclvr.generate(
130 height = height, width = width,
131 many_colors = many_colors
# Tokenize on single spaces and right-pad every description to the longest
# one with the '<unk>' token so they stack into one rectangular tensor.
134 descr = [ s.strip().split(' ') for s in descr ]
135 l = max([ len(s) for s in descr ])
136 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
# Build the vocabulary; the initialisation of `tokens` (presumably a set)
# and the outer loop over `descr` fall in a gap of this listing.
140 for t in s: tokens.add(t)
141 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
142 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
# Encode every description as a row of token ids.
144 t = [ [ self.token2id[u] for u in s ] for s in descr ]
145 data_input = torch.tensor(t, device = self.device)
# First fifth is test, the remainder is train.
147 self.test_input = data_input[:nb // 5]
148 self.train_input = data_input[nb // 5:]
150 def batches(self, split = 'train'):
151 assert split in { 'train', 'test' }
# The if/else around the two loops (and their `yield batch` bodies)
# falls in gaps of this listing; train and test are handled symmetrically.
153 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
156 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
159 def vocabulary_size(self):
160 return len(self.token2id)
# Autoregressively completes a few fixed textual primers and renders the
# resulting scene descriptions as one PNG grid per epoch.
162 def produce_results(self, n_epoch, model, nb_tokens = 50):
# Fixed primers (the list's opening bracket falls in a gap of this listing).
167 'red above green <sep> green top <sep> blue right of red <img>',
168 'there is red <sep> there is yellow <sep> there is blue <img>',
169 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
170 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
# nb_per_primer and the enclosing loop over primers are not visible here.
173 for k in range(nb_per_primer):
174 t_primer = primer.strip().split(' ')
# Generate nb_tokens tokens one at a time, re-feeding the whole sequence.
# NOTE(review): t_generated is presumably initialised to [] in the gap.
177 for j in range(nb_tokens):
178 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
179 input = torch.tensor(t, device = self.device)
180 output = model(input)
# Logits for the next token of the (single) sequence in the batch.
181 logits = output[0, -1]
# Sample from the categorical distribution; the sample()/argmax branch
# bodies fall in a gap of this listing.
182 if args.synthesis_sampling:
183 dist = torch.distributions.categorical.Categorical(logits = logits)
187 t_generated.append(self.id2token[t.item()])
# Render the completed description to an image tensor.
189 descr = [ ' '.join(t_primer + t_generated) ]
190 img += [ picoclvr.descr2img(descr) ]
# Stack all renderings and save one grid image per epoch; pixel values are
# divided by 255 so save_image receives floats in [0, 1].
192 img = torch.cat(img, 0)
193 file_name = f'result_picoclvr_{n_epoch:04d}.png'
194 torchvision.utils.save_image(img / 255.,
195 file_name, nrow = nb_per_primer, pad_value = 0.8)
196 log_string(f'wrote {file_name}')
198 ######################################################################
# WikiText-103 language-modeling task: torchtext tokenizer + vocabulary,
# lines of length [len_min, len_max] grouped into fixed-size batches padded
# with '<non>'.
200 class TaskWiki103(Task):
202 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
203 device = torch.device('cpu')):
# NOTE(review): self.device is read in produce_results but its assignment
# falls in a gap of this listing — presumably `self.device = device`.
205 self.batch_size = batch_size
206 self.len_min = len_min
207 self.len_max = len_max
208 self.min_freq = min_freq
211 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
212 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
# Optionally cap the number of lines used to build the vocabulary.
215 if args.data_size > 0:
216 train_iter = itertools.islice(train_iter, args.data_size)
# Generator over tokenized lines (its `def` header falls in a gap of this
# listing) feeding build_vocab_from_iterator below.
219 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
220 yield self.tokenizer(l)
# Vocabulary keeps tokens seen at least min_freq times; '<unk>' is the
# default for out-of-vocabulary tokens, '<non>' is the padding token.
222 self.vocab = torchtext.vocab.build_vocab_from_iterator(
224 specials = [ '<unk>', '<non>' ],
225 min_freq = self.min_freq
228 self.vocab.set_default_index(self.vocab[ '<unk>' ])
# Encode a list of token lists as one rectangular id tensor, right-padding
# every row with '<non>' to the length of the longest one.
230 def tensorize(self, s):
231 a = max(len(x) for x in s)
232 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
# Group qualifying lines of `ds` into batches of batch_size sequences;
# the accumulator initialisation and the `for l in ds:` header fall in
# gaps of this listing.
234 def yield_batches(self, ds):
237 q = self.tokenizer(l)
# Keep only lines whose token count is within [len_min, len_max].
238 if len(q) >= self.len_min and len(q) <= self.len_max:
240 if len(s) == self.batch_size:
241 yield self.tensorize(s)
# Final, possibly partial, batch (the guard around it is not visible).
245 yield self.tensorize(s)
247 def batches(self, split = 'train'):
248 data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
# Same optional truncation as in __init__.
251 if args.data_size > 0:
252 data_iter = itertools.islice(data_iter, args.data_size)
254 return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
256 def vocabulary_size(self):
257 return len(self.vocab)
# Autoregressively completes a few fixed English primers and writes the
# generations to a per-epoch text file.
259 def produce_results(self, n_epoch, model, nb_tokens = 50):
260 file_name = f'result_wiki103_{n_epoch:04d}.txt'
262 with open(file_name, 'w') as outfile:
# Fixed primers (the list's opening bracket falls in a gap of this listing).
264 'the cat is hunting a',
265 'paris is the capital',
266 'cars are convenient',
267 'the difference between men and women is',
268 'the object was blue all over and green all over it was',
269 'cherries are red and lemons are',
270 'cherries are sweet and lemons are',
271 'two plus three equals',
274 t_primer = self.tokenizer(primer)
# One token per step, re-feeding the whole primer + generation so far.
# NOTE(review): t_generated is presumably initialised to [] in the gap.
277 for j in range(nb_tokens):
279 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
280 output = model(input)
281 logits = output[0, -1]
# Sample the next token; the sample()/argmax branch bodies fall in a gap
# of this listing.
282 if args.synthesis_sampling:
283 dist = torch.distributions.categorical.Categorical(logits = logits)
287 t_generated.append(self.vocab.lookup_token(t))
# Generating the padding token ends the sequence early.
288 if t_generated[-1] == '<non>': break
290 s = ' '.join(t_generated)
292 outfile.write(f'<{primer}> {s}\n')
294 log_string(f'wrote {file_name}')
296 ######################################################################
# Pixel-level autoregressive task on MNIST: each 28x28 image is flattened to
# a sequence of 784 integer pixel values (the "tokens").
298 class TaskMNIST(Task):
300 def __init__(self, batch_size, device = torch.device('cpu')):
# NOTE(review): self.device is read in produce_results but its assignment
# falls in a gap of this listing — presumably `self.device = device`.
302 self.batch_size = batch_size
304 def batches(self, split = 'train'):
305 assert split in { 'train', 'test' }
# Remaining MNIST() keyword arguments (download flag, closing paren) fall
# in a gap of this listing.
306 data_set = torchvision.datasets.MNIST(
307 root = './data', train = (split == 'train'),
# Flatten each image to a length-784 int sequence.
310 data_input = data_set.data.view(-1, 28 * 28).long()
# NOTE(review): this uses >= 0 while the other tasks use > 0, so
# --data_size 0 yields an empty data set here — possibly intended.
311 if args.data_size >= 0:
312 data_input = data_input[:args.data_size]
# The `yield batch` body falls in a gap of this listing.
313 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
# Pixel values 0..255 — the `return` (presumably 256) is not visible here.
316 def vocabulary_size(self):
# Generate nb_samples images pixel by pixel, starting from all-zero
# sequences, and save them as one 16-wide PNG grid.
319 def produce_results(self, n_epoch, model, nb_samples = 64):
320 results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
321 for input in results.split(self.batch_size):
# Predict pixel s+1 from the pixels up to s; writing into `input` mutates
# `results` in place since split() returns views.
322 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
323 output = model(input)
324 logits = output[:, s]
# Sample the next pixel; the sample()/argmax branches and the write-back
# into input[:, s + 1] fall in a gap of this listing.
325 if args.synthesis_sampling:
326 dist = torch.distributions.categorical.Categorical(logits = logits)
# Invert (1 - x) so digits render dark-on-light; divide by 255 to get
# floats in [0, 1] for save_image.
332 image_name = f'result_mnist_{n_epoch:04d}.png'
333 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
334 image_name, nrow = 16, pad_value = 0.8)
335 log_string(f'wrote {image_name}')
337 ######################################################################
# Sanity-check helper: verifies the model is causal by measuring, via
# autograd, how much each output position depends on each input position.
# a[k, i] accumulates the squared gradient of output position k w.r.t.
# input position i; for a causal model a should be lower-triangular.
339 def check_causality(model):
# NOTE(review): `output = model(input)` and the final return/assert fall in
# gaps of this listing; `dim_model` here reads a global, not args.dim_model.
341 input = torch.rand(1, 5, dim_model).requires_grad_()
343 a = torch.zeros(output.size(1), input.size(1))
344 for k in range(output.size(1)):
345 for d in range(output.size(2)):
# One backward pass per scalar output; retain_graph allows reuse.
346 g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
347 a[k] += g.squeeze(0).pow(2).sum(1)
350 ######################################################################
352 log_string(f'device {device}')
# Dispatch on --data to build the task providing batches and vocabulary.
354 if args.data == 'wiki103':
355 task = TaskWiki103(batch_size = args.batch_size, device = device)
356 elif args.data == 'mnist':
357 task = TaskMNIST(batch_size = args.batch_size, device = device)
358 elif args.data == 'picoclvr':
359 task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
# The `else:` introducing this raise falls in a gap of this listing.
361 raise ValueError(f'Unknown dataset {args.data}.')
363 vocabulary_size = task.vocabulary_size()
365 log_string(f'vocabulary_size {vocabulary_size}')
367 ##############################
# Model construction — the constructor call (presumably the GPT class
# defined elsewhere in this file) falls in a gap of this listing; these are
# its keyword arguments.
370 vocabulary_size = vocabulary_size,
371 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
372 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
377 nb_parameters = sum(p.numel() for p in model.parameters())
378 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
380 ######################################################################
# Optimizer selection from --optim.
382 if args.optim == 'sgd':
383 optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
384 elif args.optim == 'adam':
385 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
386 elif args.optim == 'adamw':
387 optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
# The `else:` introducing this raise falls in a gap of this listing.
389 raise ValueError(f'Unknown optimizer {args.optim}.')
391 ######################################################################
# Resume from checkpoint when one exists; the `try:` header falls in a gap
# of this listing.
393 nb_epochs_finished = 0
396 checkpoint = torch.load(args.checkpoint_name, map_location = device)
397 nb_epochs_finished = checkpoint['nb_epochs_finished']
398 model.load_state_dict(checkpoint['model_state'])
399 optimizer.load_state_dict(checkpoint['optimizer_state'])
400 print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
402 except FileNotFoundError:
403 print('Starting from scratch.')
# NOTE(review): a catch-all that only prints would silently hide corrupt
# checkpoints — the except clause preceding this print is not visible here.
406 print('Error when loading the checkpoint.')
409 ######################################################################
# Main training loop: one pass over train batches, then a no-grad pass over
# test batches, results generation and checkpointing every epoch.
411 for k in range(nb_epochs_finished, args.nb_epochs):
415 nb_train_samples, acc_train_loss = 0, 0.0
417 for input in task.batches(split = 'train'):
418 input = input.to(device)
419 output = model(input)
# Next-token loss: predictions at positions :-1 against targets at 1:;
# transpose puts the class dimension where cross_entropy expects it.
420 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
# Accumulate the per-sample loss for an exact epoch average.
421 acc_train_loss += loss.item() * input.size(0)
422 nb_train_samples += input.size(0)
# backward()/step() fall in a gap of this listing, after zero_grad().
424 optimizer.zero_grad()
# Evaluation pass without gradient tracking (model.eval() not visible here).
428 with torch.autograd.no_grad():
432 nb_test_samples, acc_test_loss = 0, 0.0
434 for input in task.batches(split = 'test'):
435 input = input.to(device)
436 output = model(input)
437 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
438 acc_test_loss += loss.item() * input.size(0)
439 nb_test_samples += input.size(0)
# Perplexity = exp(mean NLL); the min(100, .) caps it to avoid overflow.
441 train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
442 test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
444 log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
446 task.produce_results(k, model)
# Checkpoint dict — its opening `checkpoint = {` falls in a gap of this
# listing; nb_epochs_finished = k + 1 makes resumption start at epoch k + 1.
449 'nb_epochs_finished': k + 1,
450 'model_state': model.state_dict(),
451 'optimizer_state': optimizer.state_dict()
454 torch.save(checkpoint, args.checkpoint_name)
456 ######################################################################