# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

######################################################################

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description = 'My own GPT.')

# Destination file of the textual training log.
parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

# -1 means "use the per-dataset default chosen later in the script".
parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

# One of 'wiki103', 'mnist', 'picoclvr' (dispatched on further down).
parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full dataset"; otherwise truncate to this many samples.
parser.add_argument('--data_size',
                    type = int, default = -1)

# One of 'sgd', 'adam', 'adamw'.
parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

# Transformer geometry.
parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# NOTE(review): a store_true flag whose default is already True can never be
# switched off from the command line -- confirm whether sampling vs argmax
# synthesis was meant to be selectable (e.g. via a --no_... counterpart).
parser.add_argument('--synthesis_sampling',
                    action='store_true', default = True)

# Skip loading an existing checkpoint and start from scratch.
parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
# picoclvr-specific options

parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()

# Opened once for the whole run; log_string() below appends to it.
log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
######################################################################

# NOTE(review): this region is incomplete in this copy -- the enclosing
# `def log_string(s):` header (and presumably a console print and a
# log_file.flush()) are missing, so `s` is unbound as written.  Recover the
# full definition from the original source before relying on this file.
t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())

# Timestamp every line written to the log file.
if log_file is not None:
    log_file.write(t + s + '\n')

######################################################################

# NOTE(review): presumably the body of a `for n in vars(args):` loop that
# dumps every parsed argument to the log -- the loop header is missing here;
# confirm against the original source.
log_string(f'args.{n} {getattr(args, n)}')
# NOTE(review): fragment of an autoregressive generation helper (called
# elsewhere in this file as `autoregression(model, batch_size, nb_samples,
# nb_tokens, ...)`).  The `def` header, the closing paren of torch.zeros,
# an apparent `if primer is not None:` guard, the `else:` of the sampling
# branch, the write-back of t_next into the results tensor, and the final
# `return` are all missing from this copy -- reconstruct before use.
nb_samples, nb_tokens_to_generate, primer = None,
device = torch.device('cpu')

# Pre-allocate one row of int64 tokens per sample; generated tokens are
# filled in position by position.
results = torch.zeros(
    nb_samples, nb_tokens_to_generate,
    dtype = torch.int64, device = device

# When a primer is given, generation starts after it rather than at index 0.
first = primer.size(1)
results = torch.cat((primer, results), 1)

for input in results.split(batch_size):
    for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
        output = model(input)
        # Logits of the next-token distribution at position s.
        logits = output[:, s]
        if args.synthesis_sampling:
            # Sample the next token from the categorical distribution.
            dist = torch.distributions.categorical.Categorical(logits = logits)
            t_next = dist.sample()
        # NOTE(review): an `else:` is missing here; as written the argmax
        # unconditionally overwrites the sampled token.
        t_next = logits.argmax(1)

######################################################################
# NOTE(review): abstract interface of a dataset task.  The enclosing
# `class Task:` header and the method bodies (presumably `pass`) are
# missing from this copy.

# Yield train or test batches as 2D integer token tensors.
def batches(self, split = 'train'):

# Number of distinct token values the model must handle.
def vocabulary_size(self):

# Produce samples / metrics with the current model after epoch n_epoch.
def produce_results(self, n_epoch, model):

######################################################################
class TaskPicoCLVR(Task):
    # NOTE(review): several lines of this class are missing from this copy
    # (closing parens, a `return` in generate_descr, the height/width/device
    # attribute assignments, the `yield` in batches, the primer-results
    # accumulator setup in produce_results).  The comments below document
    # only what the visible code shows; recover the original before running.

    # Map a list of token-string sequences (all the same length) to a 2D
    # tensor of token ids on self.device.
    def tensorize(self, descr):
        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        return torch.tensor(t, device = self.device)

    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        # Generate nb picoclvr scene descriptions, split them into tokens,
        # and right-pad them all with '<nul>' to a common length.
        def generate_descr(nb):
            descr = picoclvr.generate(
                height = self.height, width = self.width,
                nb_colors = nb_colors

            descr = [ s.strip().split(' ') for s in descr ]
            l = max([ len(s) for s in descr ])
            #descr = [ [ '<nul>' ] * (l - len(s)) + s for s in descr ]
            # Right-pad (rather than left-pad) to the longest description.
            descr = [ s + [ '<nul>' ] * (l - len(s)) for s in descr ]

        self.batch_size = batch_size

        # 250000 samples by default, split 4/5 train, 1/5 test.
        nb = args.data_size if args.data_size > 0 else 250000

        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        # NOTE(review): the `tokens = set()` initialization and the inner
        # `for s in d:` loop header appear to be missing -- `s` and `tokens`
        # are unbound as written.
        for d in [ self.train_descr, self.test_descr ]:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        # NOTE(review): the `yield batch` body of this loop is missing.
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):

    def vocabulary_size(self):
        return len(self.token2id)

    # Sample scene tokens from fixed textual primers, render the resulting
    # descriptions with picoclvr, save an image grid, and log how many of
    # the requested properties are missing from the generated scenes.
    def produce_results(self, n_epoch, model):
        # One token per grid cell plus (apparently) three extra tokens --
        # presumably separators; confirm against picoclvr's format.
        nb_tokens = self.height * self.width + 3

        # NOTE(review): `nb_per_primer` and the `result_descr = []`
        # accumulator are used below but never assigned in the visible code.
        for primer_descr in [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',

            for k in range(nb_per_primer):
                results = autoregression(
                    model, self.batch_size,
                    nb_samples = 1, nb_tokens = nb_tokens,
                    primer = self.tensorize(primer_descr),

                # Back from token ids to a space-separated description string.
                r = ' '.join([ self.id2token[t.item()] for t in results.flatten() ])
                result_descr.append(r)

        # NOTE(review): presumably `img = [ ... ]` around this comprehension;
        # the surrounding brackets are missing.
            picoclvr.descr2img(d, height = self.height, width = self.width)
            for d in result_descr

        img = torch.cat(img, 0)
        image_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(
            image_name, nrow = nb_per_primer, pad_value = 0.8

        log_string(f'wrote {image_name}')

        # Count requested vs missing properties per generated description.
        # NOTE(review): shadows the conventional `np` (numpy) name -- here it
        # holds picoclvr property counts, not the numpy module.
        np = picoclvr.nb_properties(
            height = self.height, width = self.width

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')

######################################################################
# Language-modeling task over the WikiText103 corpus (torchtext).
class TaskWiki103(Task):

    # NOTE(review): this constructor is missing a few lines in this copy --
    # apparently a `self.device = device` assignment, the header of the
    # nested token-yielding generator, the first positional argument of
    # build_vocab_from_iterator, and the call's closing paren.
    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        # Sentences outside [len_min, len_max] tokens are skipped (see
        # yield_batches); tokens rarer than min_freq map to '<unk>'.
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Optionally truncate the corpus for quick experiments.
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        # NOTE(review): presumably the body of a generator handed to
        # build_vocab_from_iterator below -- its `def` header is missing.
        for l in tqdm.tqdm(train_iter, desc = 'vocab'):
            yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq

        # Out-of-vocabulary tokens resolve to '<unk>'.
        self.vocab.set_default_index(self.vocab[ '<unk>' ])
def tensorize(self, s):
    """Convert a batch of token lists into one 2D tensor of token ids.

    Every sequence is right-padded with '<nul>' tokens up to the length
    of the longest sequence in the batch, then mapped through self.vocab.
    """
    width = max(len(seq) for seq in s)
    padded = [ seq + [ '<nul>' ] * (width - len(seq)) for seq in s ]
    rows = [ self.vocab(seq) for seq in padded ]
    return torch.tensor(rows)
# Group tokenized lines of acceptable length into batches of token-id
# tensors.  NOTE(review): several lines are missing in this copy -- the
# batch accumulator initialization (`s = [ ]`), the `for l in ds:` loop
# header, the `s.append(q)` / reset after yielding, and the final
# `if len(s) > 0:` guard before the trailing yield.  `l` and `s` are
# unbound as written; recover the original before use.
def yield_batches(self, ds):
        q = self.tokenizer(l)
        # Keep only sentences within the configured length window.
        if len(q) >= self.len_min and len(q) <= self.len_max:
            if len(s) == self.batch_size:
                yield self.tensorize(s)

    # Flush the last, possibly partial, batch.
    yield self.tensorize(s)
def batches(self, split = 'train'):
    """Yield padded token-id batches for one split of WikiText103.

    The raw line iterator is optionally truncated to args.data_size
    lines, wrapped in a progress bar, and grouped into batches by
    yield_batches().
    """
    lines = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

    # Optional truncation for quick experiments.
    if args.data_size > 0:
        lines = itertools.islice(lines, args.data_size)

    progress = tqdm.tqdm(lines, desc = f'epoch-{split}')

    return self.yield_batches(progress)
def vocabulary_size(self):
    """Return the number of entries in the torchtext vocabulary."""
    nb_entries = len(self.vocab)
    return nb_entries
# Continue a handful of fixed textual primers with the model and write the
# completions to a per-epoch text file.
# NOTE(review): missing in this copy -- the `nb_tokens = ...` assignment,
# the `for primer in [` header and the closing `]:` of the primer list, the
# `t_generated = [ ]` initialization, and the `else:` of the sampling
# branch.  `nb_tokens`, `primer` and `t_generated` are unbound as written.
def produce_results(self, n_epoch, model):

    file_name = f'result_wiki103_{n_epoch:04d}.txt'

    with open(file_name, 'w') as outfile:
            'the cat is hunting a',
            'paris is the capital',
            'cars are convenient',
            'the difference between men and women is',
            'the object was blue all over and green all over it was',
            'cherries are red and lemons are',
            'cherries are sweet and lemons are',
            'two plus three equals',

            t_primer = self.tokenizer(primer)

            for j in range(nb_tokens):

                input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                output = model(input)
                # Logits at the last position = next-token distribution.
                logits = output[0, -1]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t_next = dist.sample()
                # NOTE(review): an `else:` is missing here; as written the
                # argmax unconditionally overwrites the sampled token.
                t_next = logits.argmax()
                t_generated.append(self.vocab.lookup_token(t_next))
                # '<nul>' is the padding token, so it terminates generation.
                if t_generated[-1] == '<nul>': break

            s = ' '.join(t_generated)

            outfile.write(f'<{primer}> {s}\n')

    log_string(f'wrote {file_name}')
355 ######################################################################
# Pixel-level language modeling over MNIST: each 28x28 image is flattened
# into a sequence of 784 integer pixel values.
class TaskMNIST(Task):

    # NOTE(review): a `self.device = device` assignment (read below by
    # produce_results) appears to be missing from this copy.
    def __init__(self, batch_size, device = torch.device('cpu')):
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        # NOTE(review): the `download = ...` / transform arguments and the
        # closing paren of this constructor call are missing.
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),

        # Flatten each image into a 784-long sequence of int pixel values.
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        # NOTE(review): the `yield batch` body of this loop is missing.
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):

    # NOTE(review): the return value is missing -- pixels are byte-valued,
    # so presumably `return 256`; confirm against the original.
    def vocabulary_size(self):

    # Sample images pixel-by-pixel and save them as one PNG grid.
    # NOTE(review): the `nb_samples = ...` assignment is missing.
    def produce_results(self, n_epoch, model):
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        # Invert so digits render dark-on-light; pixel values are 0..255.
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
386 ######################################################################
log_string(f'device {device}')

# Instantiate the task selected on the command line, and record the
# per-dataset default number of epochs (used when --nb_epochs is -1).
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    # NOTE(review): the closing `device = device)` of this call and the
    # `else:` introducing the raise below are missing from this copy; as
    # written the ValueError would be raised unconditionally.
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')

##############################

# NOTE(review): presumably the keyword arguments of a model construction,
# e.g. `model = mygpt.MyGPT(` -- the constructor line, the closing paren,
# and a `model.to(device)` appear to be missing from this copy.
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
423 ######################################################################
# Select the optimizer from the command line; all variants share the same
# learning rate.
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    # Fail fast on an unknown name rather than training without an optimizer.
    # (As transmitted, this raise was unconditional -- the `else:` line had
    # been dropped, which would abort every run regardless of --optim.)
    raise ValueError(f'Unknown optimizer {args.optim}.')
434 ######################################################################
nb_epochs_finished = 0

# Resume from a checkpoint unless explicitly disabled.
# NOTE(review): the `else:` / `try:` scaffolding around the load, the
# `except:` introducing the final error branch, and presumably an exit on
# failure are missing from this copy -- as written the control flow is
# incoherent; recover the original before running.
if args.no_checkpoint:
    log_string(f'not trying to load checkpoint.')

    checkpoint = torch.load(args.checkpoint_name, map_location = device)
    nb_epochs_finished = checkpoint['nb_epochs_finished']
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

except FileNotFoundError:
    # No checkpoint on disk is the normal first-run case.
    log_string('starting from scratch.')

    log_string('error when loading the checkpoint.')
456 ######################################################################
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

# Empirical token distribution of the training set, used to report the
# train-set perplexity as a reference point for the model's perplexity.
# NOTE(review): the `token_count = 0` (or tensor) initialization is missing
# from this copy.
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
#log_string(f'train set perplexity {train_set_perplexity}')

# Main train/eval loop, resuming after any epochs already in the checkpoint.
for k in range(nb_epochs_finished, nb_epochs):

    # NOTE(review): presumably a `model.train()` call is missing here.
    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token cross-entropy; cross_entropy wants (N, C, T) logits.
        loss = F.cross_entropy(output.transpose(1, 2), input)
        # Accumulate the per-sample loss for the epoch average.
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        # NOTE(review): the `loss.backward()` and `optimizer.step()` calls
        # are missing from this copy -- as written no training would occur.
        optimizer.zero_grad()

    # Evaluation pass without gradient tracking.
    # NOTE(review): a `model.eval()` call is presumably missing here.
    with torch.autograd.no_grad():

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Clamp the exponent to avoid float overflow on diverging losses.
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(k, model)

    # NOTE(review): the `checkpoint = {` / `}` braces of this dict literal
    # are missing from this copy.
        'nb_epochs_finished': k + 1,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict()

    torch.save(checkpoint, args.checkpoint_name)
513 ######################################################################