# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

import mygpt # companion module defining the model used below

######################################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
######################################################################
parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = -1)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-3)

parser.add_argument('--learning_rate_end',
                    type = float, default = 1e-6)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

parser.add_argument('--deterministic_synthesis',
                    action='store_true', default = False)

parser.add_argument('--no_checkpoint',
                    action='store_true', default = False)

parser.add_argument('--checkpoint_name',
                    type = str, default = 'checkpoint.pth')

##############################
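# Options specific to the picoclvr task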
parser.add_argument('--picoclvr_nb_colors',
                    type = int, default = 5)

parser.add_argument('--picoclvr_height',
                    type = int, default = 12)

parser.add_argument('--picoclvr_width',
                    type = int, default = 16)

######################################################################

args = parser.parse_args()
log_file = open(args.log_filename, 'w')

torch.manual_seed(args.seed)
######################################################################
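# Write the string s to the log file and echo it on stdout, prefixed with a
# time stamp.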
def log_string(s):
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
######################################################################
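# Generate nb_samples sequences of nb_tokens_to_generate tokens with the
# model, one token at a time, optionally conditioned on a primer. Sampling is
# greedy (argmax) with --deterministic_synthesis, categorical otherwise.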
def autoregression(
        model, batch_size,
        nb_samples, nb_tokens_to_generate, primer = None,
        device = torch.device('cpu')
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate,
        dtype = torch.int64, device = device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits = logits)
                t_next = dist.sample()
            input[:, s] = t_next

    return results
######################################################################
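# Base class of the tasks: a task provides the training/test batches, the
# vocabulary size, and a method that produces samples at the end of an epoch.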
class Task:
    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass
######################################################################

import picoclvr
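# Scene-description task: the model is trained on token sequences describing
# small grids of colored squares generated by the picoclvr module.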
class TaskPicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in token_descr ])
        #token_descr = [ [ '<nul>' ] * (l - len(s)) + s for s in token_descr ]
        token_descr = [ s + [ '<nul>' ] * (l - len(s)) for s in token_descr ]
        id_descr = [ [ self.token2id[u] for u in s ] for s in token_descr ]
        return torch.tensor(id_descr, device = self.device)
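    # Remove the leading and trailing columns that contain only the <nul>
    # padding token.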
    def trim(self, x, token = '<nul>'):
        n = self.token2id[token]
        i = (1 - (F.pad(x, (1, 1), value = n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        return x[:, a:b]
    def __init__(self, batch_size,
                 height, width, nb_colors = 5,
                 device = torch.device('cpu')):

        def generate_descr(nb):
            return picoclvr.generate(
                nb,
                height = self.height, width = self.width,
                nb_colors = nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        log_string(f'generating {nb} samples (can take some time)')
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = { '<nul>' } # the padding token must be in the vocabulary
        for d in [ self.train_descr, self.test_descr ]:
            for s in d:
                for t in s.strip().split(' '): tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)
    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        input = self.train_input if split == 'train' else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc = f'epoch-{split}'):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)
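    # Generate nb_per_primer scenes for each primer description, count how
    # many of the requested properties are missing in the results, and
    # optionally save the corresponding images.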
    def test_model(self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False):
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = [ ]

        for primer_descr in primers_descr:
            results = autoregression(
                model, self.batch_size,
                nb_samples = nb_per_primer,
                nb_tokens_to_generate = nb_tokens_to_generate,
                primer = self.tensorize([ primer_descr ]).expand(nb_per_primer, -1),
                device = self.device
            )

            l = [ ' '.join([ self.id2token[t.item()] for t in r ]) for r in results ]
            result_descr += l

        np = picoclvr.nb_properties(
            result_descr,
            height = self.height, width = self.width
        )

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(f'nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}')

        np = torch.tensor(np)
        count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype = torch.int64)
        for i in range(count.size(0)):
            for j in range(count.size(1)):
                count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()

        if generate_images:
            img = [
                picoclvr.descr2img(d, height = self.height, width = self.width)
                for d in result_descr
            ]
            img = torch.cat(img, 0)
            image_name = f'result_picoclvr_{n_epoch:04d}.png'
            torchvision.utils.save_image(
                img / 255., # pixel values assumed in [0, 255]
                image_name, nrow = nb_per_primer, pad_value = 0.8
            )
            log_string(f'wrote {image_name}')

        return count
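    # End-of-epoch synthesis: sample eight scenes for each of a few
    # hand-written primer descriptions and save the rendered images.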
    def produce_results(self, n_epoch, model):
        primers_descr = [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]

        self.test_model(
            n_epoch, model,
            primers_descr,
            nb_per_primer=8, generate_images=True
        )

        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]

        # count=self.test_model(
        #     n_epoch, model,
        #     test_primers_descr,
        #     nb_per_primer=1, generate_images=False
        # )

        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
        #     for i in range(count.size(0)):
        #         for j in range(count.size(1)):
        #             f.write(f'{count[i,j]}')
        #             f.write(" " if j<count.size(1)-1 else "\n")
######################################################################
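# Word-level language modeling on WikiText-103, tokenized with torchtext's
# basic_english tokenizer.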
class TaskWiki103(Task):

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<nul>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])
    # makes a tensor from a list of list of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<nul>' ] * (a - len(x))) for x in s ])
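    # Group the lines of ds whose length is in [len_min, len_max] into
    # batches of batch_size tokenized, '<nul>'-padded sequences.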
    def yield_batches(self, ds):
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        if len(s) > 0:
            yield self.tensorize(s)
    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))

    def vocabulary_size(self):
        return len(self.vocab)
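    # Sample a continuation for each of a few fixed primer sentences and write
    # the results to a text file.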
    def produce_results(self, n_epoch, model):
        nb_tokens = 50 # nb of tokens generated per primer (value assumed)
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                for j in range(nb_tokens):
                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    input = F.pad(input, (0, 1)) # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.deterministic_synthesis:
                        t_next = logits.argmax()
                    else:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t_next = dist.sample()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    if t_generated[-1] == '<nul>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
######################################################################
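# MNIST digits modeled as sequences of 28 * 28 pixel intensities, with one
# token per gray level.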
class TaskMNIST(Task):

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
            yield batch

    def vocabulary_size(self):
        return 256 # one token per gray level

    def produce_results(self, n_epoch, model):
        nb_samples = 64 # nb of digits to generate (value assumed)
        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
######################################################################

log_string(f'device {device}')
if args.data == 'wiki103':
    nb_epochs_default = 10
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    nb_epochs_default = 25
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    nb_epochs_default = 10
    task = TaskPicoCLVR(batch_size = args.batch_size,
                        height = args.picoclvr_height,
                        width = args.picoclvr_width,
                        nb_colors = args.picoclvr_nb_colors,
                        device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
##############################
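# Build the model, move it to the device, and report its parameter count.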
model = mygpt.MyGPT( # model class from the companion mygpt module
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
######################################################################
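# Unless --no_checkpoint is given, try to resume training from a previously
# saved checkpoint (model weights, RNG states, number of finished epochs).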
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string('not trying to load checkpoint.')

else:
    try:
        checkpoint = torch.load(args.checkpoint_name)
        nb_epochs_finished = checkpoint['nb_epochs_finished']
        model.load_state_dict(checkpoint['model_state'])
        torch.set_rng_state(checkpoint['rng_state'])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
        log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.')

    except FileNotFoundError:
        log_string('starting from scratch.')

    except:
        log_string('error when loading the checkpoint.')
        exit(1)
######################################################################
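# The perplexity of the train-set token distribution (i.e. of a unigram
# model) gives a reference value for the prediction perplexities logged below.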
nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default

token_count = 0
for input in task.batches(split = 'train'):
    token_count += F.one_hot(input, num_classes = task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
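# The learning rate is interpolated geometrically from --learning_rate to
# --learning_rate_end over the epochs, unless the latter is negative.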
for n_epoch in range(nb_epochs_finished, nb_epochs):

    if args.learning_rate_end < 0:
        lr = args.learning_rate
    else:
        u = n_epoch / (nb_epochs - 1)
        lr = math.exp((1 - u) * math.log(args.learning_rate) +
                      u * math.log(args.learning_rate_end))
    log_string(f'learning_rate {lr}')

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    elif args.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
    else:
        raise ValueError(f'Unknown optimizer {args.optim}.')
    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
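    # Evaluation: compute the test loss and generate samples without tracking
    # gradients.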
    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}')

        task.produce_results(n_epoch, model)
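    # Save everything needed to resume training: finished epoch count, model
    # weights, and RNG states.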
    checkpoint = {
        'nb_epochs_finished': n_epoch + 1,
        'model_state': model.state_dict(),
        'rng_state': torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()

    torch.save(checkpoint, args.checkpoint_name)
######################################################################