# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import argparse
import itertools
import math
import sys
import time

import torch
import torchtext
import torchvision
import tqdm

from torch.nn import functional as F
16 ######################################################################
# Run on a GPU when one is available, fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
20 ######################################################################
def _str2bool(x):
    """Parse a boolean command-line value.

    argparse's `type = bool` is a trap: bool() is applied to the raw
    string, and any non-empty string -- including 'False' -- is truthy,
    so '--download False' would silently enable the flag. Accept the
    usual spellings explicitly instead.
    """
    if x.lower() in { '1', 'true', 'yes' }: return True
    if x.lower() in { '0', 'false', 'no' }: return False
    raise argparse.ArgumentTypeError(f'invalid boolean value {x}')

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    type = _str2bool, default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--synthesis_sampling',
                    type = _str2bool, default = True)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
84 ######################################################################
def log_string(s):
    """Log s, prefixed with a wall-clock timestamp, to the log file (when
    one is open) and to stdout."""
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        # Flush so the log survives a crash or a kill -- the file is
        # never explicitly closed anywhere in this script.
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

# Record the full configuration so that every run is self-describing.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
99 ######################################################################
class Task:
    """Interface of a training task: a stream of token batches, the size
    of its vocabulary, and a routine producing sample results with a
    model. Subclasses must override all three methods; raising
    NotImplementedError here makes a missing override fail loudly instead
    of silently returning None."""

    def batches(self, split = 'train'):
        # Yield training or test batches as tensors of token ids.
        raise NotImplementedError

    def vocabulary_size(self):
        # Number of distinct tokens the task uses.
        raise NotImplementedError

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Generate and save sample outputs with the given model.
        raise NotImplementedError
111 ######################################################################
class TaskPicoCLVR(Task):
    """Task over picoclvr scene descriptions: space-separated token
    sequences, padded with '<unk>' to a common length."""

    def __init__(self, batch_size,
                 height = 6, width = 8, many_colors = False,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.device = device

        # data_size <= 0 means the default corpus size.
        nb = args.data_size if args.data_size > 0 else 250000

        descr = picoclvr.generate(
            nb,
            height = height, width = width,
            many_colors = many_colors
        )

        # Tokenize and pad every description to the maximum length.
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        # Sort the tokens so that the token <-> id mapping is
        # deterministic across runs: enumerating a raw set depends on
        # hash randomization, which made checkpoints and generated
        # results incompatible between processes.
        tokens = set()
        for s in descr:
            for t in s: tokens.add(t)
        tokens = sorted(tokens)
        self.token2id = { t: n for n, t in enumerate(tokens) }
        self.id2token = { n: t for n, t in enumerate(tokens) }

        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        data_input = torch.tensor(t, device = self.device)

        # First fifth for test, the rest for training.
        self.test_input = data_input[:nb // 5]
        self.train_input = data_input[nb // 5:]

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(data.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Generate nb_tokens tokens from a few fixed primers and save
        the corresponding images in a single grid."""
        img = [ ]
        nb_per_primer = 8

        for primer in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                t_primer = primer.strip().split(' ')
                t_generated = [ ]

                # Auto-regressive generation, one token at a time,
                # re-running the model on the full prefix at every step.
                for j in range(nb_tokens):
                    t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
                    input = torch.tensor(t, device = self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.id2token[t.item()])

                descr = [ ' '.join(t_primer + t_generated) ]
                img += [ picoclvr.descr2img(descr) ]

        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(img / 255.,
                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
        log_string(f'wrote {file_name}')
195 ######################################################################
class TaskWiki103(Task):
    """Language modeling on WikiText-103, batching lines whose token
    length is within [len_min, len_max], padded with '<non>'."""

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging: cap the amount of data when asked to.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        # Keep only tokens appearing at least min_freq times.
        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        # Out-of-vocabulary words are mapped to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        # Pad every sequence with '<non>' to the length of the longest.
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            # Keep only lines of reasonable token length.
            if self.len_min <= len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the last, possibly partial, batch.
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Complete a few fixed primers with up to nb_tokens generated
        tokens and write the results to a text file."""
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t))
                    # '<non>' is the padding token: stop at the first one.
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
293 ######################################################################
class TaskMNIST(Task):
    """Auto-regressive modeling of MNIST images flattened to sequences of
    28 * 28 pixel values."""

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size
        self.device = device

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = args.download
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        # Consistent with the other tasks: data_size <= 0 means the full
        # data set. The original test was `>= 0`, which made
        # `--data_size 0` produce an empty stream and a division by zero
        # in the trainer.
        if args.data_size > 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        # One token per possible grey level (pixels are 8-bit values).
        return 256

    def produce_results(self, n_epoch, model, nb_samples = 64):
        """Generate nb_samples images pixel by pixel and save them as a
        single image grid."""
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            # input is a view into results, so writing into it below
            # fills results in place.
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t = dist.sample()
                else:
                    t = logits.argmax(1)
                input[:, s + 1] = t

        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
334 ######################################################################
def check_causality(model, dim_model = None):
    """Sanity check that model is auto-regressive.

    Returns a matrix a where a[k, i] accumulates the squared gradients of
    every component of the output at position k with respect to the input
    at position i. For a causal model, a must be zero strictly above the
    diagonal.

    dim_model defaults to args.dim_model; the original referenced a bare
    `dim_model` global that is not defined at module level.
    """
    if dim_model is None: dim_model = args.dim_model
    input = torch.rand(1, 5, dim_model).requires_grad_()
    output = model(input)
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
    return a
347 ######################################################################
log_string(f'device {device}')

# Dispatch table from the --data value to the task class.
picked_task = {
    'wiki103': TaskWiki103,
    'mnist': TaskMNIST,
    'picoclvr': TaskPicoCLVR,
}.get(args.data)

if picked_task is None:
    raise ValueError(f'Unknown dataset {args.data}.')

task = picked_task(batch_size = args.batch_size, device = device)

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
364 ##############################
367 vocabulary_size = vocabulary_size,
368 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
369 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
# Report the model size.
nb_parameters = sum(map(torch.numel, model.parameters()))
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
377 ######################################################################
# Dispatch from the --optim value to the corresponding constructor.
optimizer_ctor = {
    'sgd': torch.optim.SGD,
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW,
}.get(args.optim)

if optimizer_ctor is None:
    raise ValueError(f'Unknown optimizer {args.optim}.')

optimizer = optimizer_ctor(model.parameters(), lr = args.learning_rate)
388 ######################################################################
nb_epochs_finished = 0

# Resume from the checkpoint when one exists.
try:
    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    print('Starting from scratch.')

# Narrowed from a bare except, which would also have swallowed
# KeyboardInterrupt and SystemExit.
except Exception:
    print('Error when loading the checkpoint.')
    sys.exit(1)
406 ######################################################################
for k in range(nb_epochs_finished, args.nb_epochs):

    # Back to train mode: the evaluation phase below switches to eval.
    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Predict token t+1 from the tokens up to t: drop the last output
        # position and the first input position.
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Modern spelling of torch.autograd.no_grad().
    with torch.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Cap the argument of exp to avoid an overflow on diverging runs.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)

    # Save everything needed to resume, including the epoch count.
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
453 ######################################################################