# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

import mygpt, picoclvr
16 ######################################################################
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
20 ######################################################################
# Command-line configuration. Every knob of the run (data, model size,
# optimizer, sampling behavior) is an option with a sensible default.
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    action='store_true', default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full data set"; a positive value truncates it (debug).
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# BUG FIX: the original declared this flag with action='store_true' and
# default=True, making it a no-op — sampling could never be disabled from
# the command line. --synthesis_sampling keeps its old behavior, and the
# new --no_synthesis_sampling switches generation to greedy argmax.
parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

parser.add_argument('--no_synthesis_sampling',
                    dest = 'synthesis_sampling', action = 'store_false')

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

parser.add_argument('--picoclvr_many_colors',
                    action='store_true', default = False)
78 ######################################################################
# Parse the command line; everything downstream reads this `args` namespace.
args = parser.parse_args()

# Opened in write mode: a new run truncates any previous log of the same name.
log_file = open(args.log_filename, 'w')

# Seed the global PRNG so runs are reproducible (default seed 0).
torch.manual_seed(args.seed)
87 ######################################################################
def log_string(s):
    """Write `s` to the run log prefixed with a timestamp, and echo it to stdout.

    NOTE(review): the `def` line and the flush/print tail were lost in
    extraction and are restored here from the surrounding fragments.
    """
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

# Record the full configuration at the top of the log, one option per line.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
102 ######################################################################
class Task:
    """Abstract interface shared by the concrete tasks below.

    NOTE(review): the `class Task:` header was lost in extraction; restored
    so the `Task` base of the subclasses resolves.
    """

    def batches(self, split = 'train'):
        # Yield batches of token ids for the requested split.
        pass

    def vocabulary_size(self):
        # Number of distinct token values the model must handle.
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        # Generate samples with `model` and write them to disk / the log.
        pass
114 ######################################################################
class TaskPicoCLVR(Task):
    """Synthetic picoCLVR scenes described as token sequences.

    Generates `nb` textual scene descriptions with the `picoclvr` module,
    pads them to equal length with '<unk>', builds a vocabulary, and keeps
    a 1/5 test / 4/5 train split as LongTensors.

    NOTE(review): several structural lines (split/yield branches, sampling
    branch, primer-loop header, closing parentheses) were lost in extraction
    and are restored here from the surrounding fragments.
    """

    def __init__(self, batch_size,
                 height = 6, width = 8, many_colors = False,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        descr = picoclvr.generate(
            nb,
            height = height, width = width,
            many_colors = many_colors
        )

        # Tokenize and pad every description to the maximum length.
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        # Build the token <-> id maps from the tokens actually present.
        tokens = set()
        for s in descr:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        data_input = torch.tensor(t, device = self.device)

        # First fifth for test, the rest for training.
        self.test_input = data_input[:nb // 5]
        self.train_input = data_input[nb // 5:]

    def batches(self, split = 'train'):
        """Yield batches of token ids from the requested split."""
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
                yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def generate(self, primer, model, nb_tokens):
        """Autoregressively extend `primer` by `nb_tokens` tokens with `model`."""
        t_primer = primer.strip().split(' ')
        t_generated = [ ]

        for j in range(nb_tokens):
            # Re-encode the whole sequence each step and read the last logits.
            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
            input = torch.tensor(t, device = self.device)
            output = model(input)
            logits = output[0, -1]
            if args.synthesis_sampling:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t = dist.sample()
            else:
                t = logits.argmax()
            t_generated.append(self.id2token[t.item()])

        return ' '.join(t_primer + t_generated)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Sample completions of fixed primers and save them as an image grid."""
        nb_per_primer = 8
        descr = [ ]

        for primer in [
                'red above green <sep> green top <sep> blue right of red <img>',
                'there is red <sep> there is yellow <sep> there is blue <img>',
                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:
            for k in range(nb_per_primer):
                descr.append(self.generate(primer, model, nb_tokens))

        img = [ picoclvr.descr2img(d) for d in descr ]
        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(img / 255.,
                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
        log_string(f'wrote {file_name}')

        # How many of the requested properties the sampled scenes fail to
        # satisfy. (Typo 'nb_misssing' fixed in the log message.)
        log_string(f'nb_missing {picoclvr.nb_missing_properties(descr)}')
206 ######################################################################
class TaskWiki103(Task):
    """WikiText-103 language modeling via torchtext.

    Builds a vocabulary from the training split (tokens seen at least
    `min_freq` times), then serves padded batches of sentences whose token
    length lies in [len_min, len_max]. '<non>' is the padding token, '<unk>'
    the default for out-of-vocabulary words.

    NOTE(review): a few structural lines (the inner vocab-iterator def, the
    batch accumulator in yield_batches, the sampling branch) were lost in
    extraction and are restored here from the surrounding fragments.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debugging: cap the amount of data scanned for the vocab.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        # Unknown words map to '<unk>' instead of raising.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        """Convert a list of token lists into one LongTensor, padded with '<non>'."""
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        """Group length-filtered lines from `ds` into batch-size tensors."""
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the final, possibly smaller, batch.
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debugging: cap the number of lines per epoch.
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        """Complete a set of fixed text primers and write them to a result file."""
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                    'the cat is hunting a',
                    'paris is the capital',
                    'cars are convenient',
                    'the difference between men and women is',
                    'the object was blue all over and green all over it was',
                    'cherries are red and lemons are',
                    'cherries are sweet and lemons are',
                    'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t))
                    # Stop early once the model emits the padding token.
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
304 ######################################################################
class TaskMNIST(Task):
    """MNIST images modeled as sequences of 784 pixel-intensity tokens.

    NOTE(review): the dropped lines (device attribute, download argument,
    yield, vocabulary size, sampling branch, pixel write-back) were lost in
    extraction and are restored here from the surrounding fragments.
    """

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        """Yield batches of flattened images as (B, 784) LongTensors."""
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = args.download
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        # One token per 8-bit pixel intensity.
        return 256

    def produce_results(self, n_epoch, model, nb_samples = 64):
        """Autoregressively synthesize pixel sequences and save an image grid."""
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            # `input` is a view into `results`, so writes below fill it in place.
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t = dist.sample()
                else:
                    t = logits.argmax(1)
                input[:, s + 1] = t

        image_name = f'result_mnist_{n_epoch:04d}.png'
        # Invert so digits are dark on a light background.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
345 ######################################################################
def check_causality(model, dim_model = None):
    """Empirically check that `model` is causal (autoregressive).

    Feeds a random (1, 5, dim_model) input and accumulates, for every output
    position k, the squared gradient magnitude with respect to each input
    position. For a causal model the returned (T, T) matrix is
    lower-triangular: output k must not depend on inputs after k.

    BUG FIX: the original referenced an undefined bare `dim_model` (NameError
    when called) and the forward pass line was missing; `dim_model` is now an
    optional parameter defaulting to args.dim_model, and `output = model(input)`
    is restored.
    """
    if dim_model is None:
        dim_model = args.dim_model
    input = torch.rand(1, 5, dim_model).requires_grad_()
    output = model(input)
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
    return a
358 ######################################################################
log_string(f'device {device}')

# Instantiate the task named on the command line; it owns data loading,
# batching, and sample generation. (The dropped `else:` before the raise
# is restored.)
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
375 ##############################
# Build the GPT with the sizes from the command line. (The dropped
# `model = mygpt.MyGPT(` header, closing parenthesis, and device move are
# restored.)
model = mygpt.MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

# Move parameters to the device before the optimizer captures them.
model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
388 ######################################################################
# Select the optimizer named on the command line. (The dropped `else:`
# before the raise is restored.)
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
399 ######################################################################
nb_epochs_finished = 0

# Resume from a previous run if a checkpoint exists; otherwise start fresh.
# (The dropped `try:`, second except, and exit are restored.)
try:
    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    print('Starting from scratch.')

except Exception:
    # A corrupt or incompatible checkpoint is fatal: stop rather than
    # silently retrain over it. Narrowed from a bare `except:` so
    # KeyboardInterrupt/SystemExit are not swallowed.
    print('Error when loading the checkpoint.')
    exit(1)
417 ######################################################################
# Main loop: one pass over the training split, then an evaluation pass,
# sample generation, and a checkpoint per epoch.
# NOTE(review): model.train()/model.eval(), loss.backward(), optimizer.step()
# and the checkpoint-dict braces were lost in extraction; restored from the
# numbering gaps — confirm against the original.
for k in range(nb_epochs_finished, args.nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token prediction: logits at position t are scored against the
        # token at t+1; cross_entropy wants (N, C, T), hence the transpose.
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Cap the exponent at 100 so a diverging loss cannot overflow exp().
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)

    # Save after every epoch so an interrupted run resumes at epoch k+1.
    checkpoint = {
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()
    }

    torch.save(checkpoint, args.checkpoint_name)
464 ######################################################################