3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 import math, sys, argparse, time, tqdm, itertools
10 import torch, torchtext, torchvision
12 from torch.nn import functional as F
######################################################################

# Pick the compute device once at startup: CUDA when a GPU is visible,
# otherwise the CPU.  Tasks and batches below are moved onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

######################################################################
def _str2bool(v):
    """Parse a command-line boolean value.

    argparse's ``type = bool`` is a classic footgun: any non-empty
    string is truthy, so ``--download False`` used to yield ``True``.
    This converter accepts the usual spellings explicitly and rejects
    anything else with a clear error.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}.')

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

# Was `type = bool`, which made `--download False` evaluate to True.
parser.add_argument('--download',
                    type = _str2bool, default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

# Which task to train on: 'wiki103', 'mnist' or 'picoclvr'.
parser.add_argument('--data',
                    type = str, default = 'wiki103')

# Restrict the data set to this many samples when positive.
parser.add_argument('--data_size',
                    type = int, default = -1)

# Optimizer choice: 'sgd', 'adam' or 'adamw'.
parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

# Transformer geometry.
parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# When True, synthesis samples each token from the output distribution
# instead of making a deterministic pick.  Was `type = bool` (same
# footgun as --download).
parser.add_argument('--synthesis_sampling',
                    type = _str2bool, default = True)

######################################################################
args = parser.parse_args()

# Opened once and kept open for the whole run; log_string() writes
# every log line here.
log_file = open(args.log_filename, 'w')

# Seed the RNG so that runs are reproducible.
torch.manual_seed(args.seed)
81 ######################################################################
# NOTE(review): this extract is damaged — each line carries a residual
# source line number, and the gaps in that numbering show lost lines.
# In this span the `def log_string(s):` header that should precede the
# timestamped write, and the loop header that binds `n` (presumably
# iterating over vars(args)) before the last call, are missing.
# Confirm against the original file.
84 t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
# Prefix each log line with a timestamp; mirror it to the log file.
86 if log_file is not None:
87     log_file.write(t + s + '\n')
# Log every command-line argument value at startup.
94 log_string(f'args.{n} {getattr(args, n)}')
96 ######################################################################
# NOTE(review): interface of the Task base class implemented by the
# three tasks below.  The `class Task:` header and the method bodies
# (presumably `pass`) were lost in extraction — confirm.
# Yield minibatches for the requested split ('train' or 'test').
99 def batches(self, split = 'train'):
# Number of distinct token ids the task's data uses.
102 def vocabulary_size(self):
# Generate samples with `model` after epoch `n_epoch` and save them.
105 def produce_results(self, n_epoch, model, nb_tokens = 50):
108 ######################################################################
# NOTE(review): lines were lost in extraction (gaps in the residual
# numbering): the `tokens = set()` accumulator and the loop over `descr`
# before line 126, the `self.device = device` assignment read at line
# 131, the `yield batch` statements in batches(), the initialisation of
# `img`, `nb_per_primer`, `t_generated` and the primer loop header in
# produce_results(), and the deterministic else-branch of the sampling
# test.  Confirm against the original file.
112 class TaskPicoCLVR(Task):
# Builds a synthetic picoclvr data set of `nb` scene descriptions
# (args.data_size when positive, 250000 otherwise), pads them to one
# common length with '<unk>', and splits 1/5 test / 4/5 train.
114 def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
115 self.batch_size = batch_size
117 nb = args.data_size if args.data_size > 0 else 250000
119 descr = picoclvr.generate(nb, height = height, width = width)
120 descr = [ s.strip().split(' ') for s in descr ]
# Pad every tokenized description to the maximum length observed.
121 l = max([ len(s) for s in descr ])
122 descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
# Bidirectional token <-> integer-id maps over the observed vocabulary.
126 for t in s: tokens.add(t)
127 self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
128 self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
130 t = [ [ self.token2id[u] for u in s ] for s in descr ]
131 data_input = torch.tensor(t, device = self.device)
# First fifth is the test split, the rest the train split.
133 self.test_input = data_input[:nb // 5]
134 self.train_input = data_input[nb // 5:]
# Iterates over the requested split in chunks of self.batch_size; the
# `yield batch` lines are among the lost lines.
136 def batches(self, split = 'train'):
137 assert split in { 'train', 'test' }
139 for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
142 for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
145 def vocabulary_size(self):
146 return len(self.token2id)
# Autoregressively completes fixed primers token by token, renders the
# resulting descriptions with picoclvr.descr2img, and saves one PNG
# grid per epoch.
148 def produce_results(self, n_epoch, model, nb_tokens = 50):
153 'red above green <sep> green top <sep> blue right of red <img>',
154 'there is red <sep> there is yellow <sep> there is blue <img>',
155 'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
156 'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
159 for k in range(nb_per_primer):
160 t_primer = primer.strip().split(' ')
# One full forward pass per generated token: feed primer + tokens so
# far, keep the logits of the last position only.
163 for j in range(nb_tokens):
164 t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
165 input = torch.tensor(t, device = self.device)
166 output = model(input)
167 logits = output[0, -1]
# Sample from the categorical distribution when --synthesis_sampling
# is set; the deterministic alternative branch was lost in extraction.
168 if args.synthesis_sampling:
169 dist = torch.distributions.categorical.Categorical(logits = logits)
173 t_generated.append(self.id2token[t.item()])
175 descr = [ ' '.join(t_primer + t_generated) ]
176 img += [ picoclvr.descr2img(descr) ]
# Scale 0..255 pixel values into [0, 1] and save all samples as a grid.
178 img = torch.cat(img, 0)
179 file_name = f'result_picoclvr_{n_epoch:04d}.png'
180 torchvision.utils.save_image(img / 255.,
181 file_name, nrow = nb_per_primer, pad_value = 0.8)
182 log_string(f'wrote {file_name}')
184 ######################################################################
# NOTE(review): lines were lost in extraction (gaps in the residual
# numbering): the `self.device = device` assignment, the inner
# `def yield_tokens():` header before line 205 and its use at the
# build_vocab_from_iterator call, the batch accumulator (`s = [ ]`) and
# the `for l in ds:` loop header in yield_batches(), the primer-loop
# headers and `t_generated = [ ]` in produce_results(), and the
# deterministic else-branch of the sampling test.  Confirm.
186 class TaskWiki103(Task):
# Language modelling on WikiText-103: tokenizes the train split with
# the 'basic_english' tokenizer and builds a vocabulary with the given
# minimum frequency; '<unk>' is the default for out-of-vocabulary
# tokens, '<non>' is used as padding.
188 def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
189 device = torch.device('cpu')):
191 self.batch_size = batch_size
192 self.len_min = len_min
193 self.len_max = len_max
194 self.min_freq = min_freq
197 self.tokenizer = torchtext.data.get_tokenizer('basic_english')
198 train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
# Optionally restrict the vocabulary-building pass to the first
# args.data_size lines.
201 if args.data_size > 0:
202 train_iter = itertools.islice(train_iter, args.data_size)
205 for l in tqdm.tqdm(train_iter, desc = 'vocab'):
206 yield self.tokenizer(l)
208 self.vocab = torchtext.vocab.build_vocab_from_iterator(
210 specials = [ '<unk>', '<non>' ],
211 min_freq = self.min_freq
# Any unknown token maps to '<unk>'.
214 self.vocab.set_default_index(self.vocab[ '<unk>' ])
# Pads the token lists in `s` to a common length with '<non>' and maps
# them to a (batch, length) id tensor.
216 def tensorize(self, s):
217 a = max(len(x) for x in s)
218 return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
# Keeps the lines whose token count lies in [len_min, len_max] and
# yields them in full batches; line 231 flushes the final partial batch.
220 def yield_batches(self, ds):
223 q = self.tokenizer(l)
224 if len(q) >= self.len_min and len(q) <= self.len_max:
226 if len(s) == self.batch_size:
227 yield self.tensorize(s)
231 yield self.tensorize(s)
233 def batches(self, split = 'train'):
234 data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
237 if args.data_size > 0:
238 data_iter = itertools.islice(data_iter, args.data_size)
240 return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
242 def vocabulary_size(self):
243 return len(self.vocab)
# Autoregressively completes fixed text primers, stopping early when
# the model emits the '<non>' padding token, and writes one text file
# of completions per epoch.
245 def produce_results(self, n_epoch, model, nb_tokens = 50):
246 file_name = f'result_wiki103_{n_epoch:04d}.txt'
248 with open(file_name, 'w') as outfile:
250 'the cat is hunting a',
251 'paris is the capital',
252 'cars are convenient',
253 'the difference between men and women is',
254 'the object was blue all over and green all over it was',
255 'cherries are red and lemons are',
256 'cherries are sweet and lemons are',
257 'two plus three equals',
260 t_primer = self.tokenizer(primer)
# One forward pass per generated token; keep the last position's logits.
263 for j in range(nb_tokens):
265 input = self.tensorize([ t_primer + t_generated ]).to(self.device)
266 output = model(input)
267 logits = output[0, -1]
# Sample when --synthesis_sampling is set; the deterministic
# alternative branch was lost in extraction.
268 if args.synthesis_sampling:
269 dist = torch.distributions.categorical.Categorical(logits = logits)
273 t_generated.append(self.vocab.lookup_token(t))
274 if t_generated[-1] == '<non>': break
276 s = ' '.join(t_generated)
278 outfile.write(f'<{primer}> {s}\n')
280 log_string(f'wrote {file_name}')
282 ######################################################################
# NOTE(review): lines were lost in extraction (gaps in the residual
# numbering): the `self.device = device` assignment, the trailing
# kwargs of the MNIST constructor (presumably `download = args.download`),
# the `yield batch` in batches(), the body of vocabulary_size()
# (presumably `return 256`, one id per 8-bit pixel value — confirm),
# and in produce_results() the sampled/deterministic token choice and
# the write-back of the generated pixel into the sequence.  Confirm.
284 class TaskMNIST(Task):
286 def __init__(self, batch_size, device = torch.device('cpu')):
288 self.batch_size = batch_size
# Treats each 28x28 MNIST image as a flat sequence of 784 integer
# pixel intensities.
290 def batches(self, split = 'train'):
291 assert split in { 'train', 'test' }
292 data_set = torchvision.datasets.MNIST(
293 root = './data', train = (split == 'train'),
296 data_input = data_set.data.view(-1, 28 * 28).long()
# Note: `>= 0` (not `> 0`), so --data_size 0 yields an empty data set.
297 if args.data_size >= 0:
298 data_input = data_input[:args.data_size]
299 for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
302 def vocabulary_size(self):
# Generates nb_samples images pixel by pixel, starting from all-zero
# sequences, and saves them inverted (1 - x) as a 16-per-row PNG grid.
305 def produce_results(self, n_epoch, model, nb_samples = 64):
306 results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
# `input` is a view into `results`, so generated pixels accumulate
# in place across positions — TODO confirm against the lost lines.
307 for input in results.split(self.batch_size):
308 for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
309 output = model(input)
310 logits = output[:, s]
311 if args.synthesis_sampling:
312 dist = torch.distributions.categorical.Categorical(logits = logits)
318 image_name = f'result_mnist_{n_epoch:04d}.png'
319 torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
320 image_name, nrow = 16, pad_value = 0.8)
321 log_string(f'wrote {image_name}')
323 ######################################################################
# Gradient-based autoregressive-masking check: a[k, i] accumulates the
# squared gradient of output position k w.r.t. input position i, so a
# causal model must have a zero strictly-upper-triangular part.
# NOTE(review): lines were lost in extraction — the forward pass
# (`output = model(input)`) before line 329 and the final return /
# verification of `a`.  `dim_model` is also an unbound name here
# (presumably args.dim_model).  Confirm against the original file.
325 def check_causality(model):
327 input = torch.rand(1, 5, dim_model).requires_grad_()
329 a = torch.zeros(output.size(1), input.size(1))
330 for k in range(output.size(1)):
# Sum the gradient contribution over every output channel d.
331 for d in range(output.size(2)):
332 g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
333 a[k] += g.squeeze(0).pow(2).sum(1)
336 ######################################################################
338 log_string(f'device {device}')
# Instantiate the task selected by --data.
# NOTE(review): the `else:` (original line 346) before the raise was
# lost in extraction; as shown the raise looks unconditional.  Confirm.
340 if args.data == 'wiki103':
341 task = TaskWiki103(batch_size = args.batch_size, device = device)
342 elif args.data == 'mnist':
343 task = TaskMNIST(batch_size = args.batch_size, device = device)
344 elif args.data == 'picoclvr':
345 task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
347 raise ValueError(f'Unknown dataset {args.data}.')
# The model's output dimension is the task's vocabulary size.
349 vocabulary_size = task.vocabulary_size()
351 log_string(f'vocabulary_size {vocabulary_size}')
353 ##############################
# NOTE(review): the constructor call header (original lines 354-355,
# presumably `model = mygpt.MyGPT(`) and the closing lines that follow
# (presumably `)` and `model.to(device)`, original 359-365) were lost
# in extraction — the lines below are its keyword arguments.  Confirm.
356 vocabulary_size = vocabulary_size,
357 dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
358 nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
# Log the parameter count, also rounded down to whole millions.
361 nb_parameters = sum(p.numel() for p in model.parameters())
362 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
366 ######################################################################
# Build the optimizer selected by --optim, all with --learning_rate.
# NOTE(review): the `else:` (original line 374) before the raise was
# lost in extraction; as shown the raise looks unconditional.  Confirm.
368 if args.optim == 'sgd':
369 optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
370 elif args.optim == 'adam':
371 optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
372 elif args.optim == 'adamw':
373 optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
375 raise ValueError(f'Unknown optimizer {args.optim}.')
# Main training loop.
# NOTE(review): lines were lost in extraction (gaps in the residual
# numbering): presumably `model.train()` before the train pass,
# `loss.backward()` / `optimizer.step()` after `optimizer.zero_grad()`
# (original 391-393), and `model.eval()` inside the no_grad block
# (original 395-397).  Confirm against the original file.
377 for k in range(args.nb_epochs):
381 nb_train_samples, acc_train_loss = 0, 0.0
383 for input in task.batches(split = 'train'):
384 input = input.to(device)
385 output = model(input)
# Next-token objective: predictions at positions 0..T-2 against
# targets 1..T-1; transpose to (N, C, T) as F.cross_entropy expects.
386 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
# Accumulate the loss weighted by batch size for a per-sample mean.
387 acc_train_loss += loss.item() * input.size(0)
388 nb_train_samples += input.size(0)
390 optimizer.zero_grad()
# Evaluation pass over the test split, without gradient tracking.
394 with torch.autograd.no_grad():
398 nb_test_samples, acc_test_loss = 0, 0.0
400 for input in task.batches(split = 'test'):
401 input = input.to(device)
402 output = model(input)
403 loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
404 acc_test_loss += loss.item() * input.size(0)
405 nb_test_samples += input.size(0)
# Perplexity = exp(mean loss), with the exponent clamped at 100 to
# avoid float overflow early in training.
407 train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
408 test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
410 log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
# Generate and save task-specific samples at the end of every epoch.
412 task.produce_results(k, model)
414 ######################################################################