# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import argparse
import itertools
import math
import sys
import time

import torch
import torchtext
import torchvision
import tqdm
from torch import nn
from torch.nn import functional as F
14 ######################################################################
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
18 ######################################################################
def str2bool(v):
    """Parse a textual boolean for argparse.

    `type = bool` is an argparse trap: bool('False') is True because any
    non-empty string is truthy, so `--download False` would silently be
    treated as True. Parse the text explicitly instead.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in { 'yes', 'true', 't', '1' }:
        return True
    if v.lower() in { 'no', 'false', 'f', '0' }:
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

parser = argparse.ArgumentParser(description = 'My own GPT.')

parser.add_argument('--log_filename',
                    type = str, default = 'train.log')

parser.add_argument('--download',
                    type = str2bool, default = False)

parser.add_argument('--seed',
                    type = int, default = 0)

parser.add_argument('--nb_epochs',
                    type = int, default = 100)

parser.add_argument('--batch_size',
                    type = int, default = 25)

parser.add_argument('--data',
                    type = str, default = 'wiki103')

# -1 means "use the full data set"
parser.add_argument('--data_size',
                    type = int, default = -1)

parser.add_argument('--optim',
                    type = str, default = 'adam')

parser.add_argument('--learning_rate',
                    type = float, default = 1e-4)

parser.add_argument('--dim_model',
                    type = int, default = 512)

parser.add_argument('--dim_keys',
                    type = int, default = 64)

parser.add_argument('--dim_hidden',
                    type = int, default = 2048)

parser.add_argument('--nb_heads',
                    type = int, default = 8)

parser.add_argument('--nb_blocks',
                    type = int, default = 12)

parser.add_argument('--dropout',
                    type = float, default = 0.1)

# True: sample from the predicted distribution; False: greedy argmax
parser.add_argument('--synthesis_sampling',
                    type = str2bool, default = True)
70 ######################################################################
args = parser.parse_args()

# Opened for the whole run; the logging helper writes into it.
log_file = open(args.log_filename, 'w')

# Reproducibility: fix the RNG seed before any parameter initialization.
torch.manual_seed(args.seed)
79 ######################################################################
def log_string(s):
    """Write s, prefixed with a timestamp, to the log file (if any) and stdout."""
    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
    if log_file is not None:
        log_file.write(t + s + '\n')
        log_file.flush()
    print(t + s)
    sys.stdout.flush()

######################################################################

# Record the full configuration in the log, one line per argument.
for n in vars(args):
    log_string(f'args.{n} {getattr(args, n)}')
94 ##############################
class Residual(nn.Module):
    """Residual wrapper: computes x + f(x).

    With a single module argument f is that module; with several, they are
    chained into an nn.Sequential before the skip connection is added.
    """

    def __init__(self, *f):
        super().__init__()
        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)

    def forward(self, x):
        return x + self.f(x)
104 ##############################
class PositionalEncoding(nn.Module):
    """Adds the sinusoidal positional encoding of Vaswani et al. 2017.

    PE_{t,2i}   = sin(t / (L^{2i/D}))
    PE_{t,2i+1} = cos(t / (L^{2i/D}))

    with L = len_max and D the embedding dimension. The cos rows are
    obtained as sin shifted by pi/2. Input x is (N, T, D); the encoding
    is broadcast over the batch dimension.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, x):
        t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None]
        j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :]
        # k is 0 on even dimensions (plain sin) and 1 on odd ones, where the
        # pi/2 phase shift turns sin into cos. This line was lost in the
        # paste, leaving k undefined (NameError).
        k = j % 2
        return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :]
120 ##############################
class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Maps (N, T, dim_in) to (N, T, nb_heads * dim_v). When causal is True,
    position t only attends to positions <= t.
    """

    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0):
        super().__init__()

        def randw(*d):
            # Gaussian init scaled by 1/sqrt(fan-in) (last dimension)
            return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1])))

        self.wq = randw(nb_heads, dim_qk, dim_in)
        self.wk = randw(nb_heads, dim_qk, dim_in)
        self.wv = randw(nb_heads, dim_v, dim_in)
        # These two lines were lost in the paste; without them `causal` was
        # silently ignored and the mask applied unconditionally.
        self.causal = causal
        self.attention_dropout = attention_dropout

    def forward(self, x):
        q = torch.einsum('ntc,hdc->nhtd', x, self.wq)
        k = torch.einsum('ntc,hdc->nhtd', x, self.wk)
        v = torch.einsum('ntc,hdc->nhtd', x, self.wv)
        r = math.sqrt(q.size(3))
        a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r)
        if self.causal:
            # Forbid attention to the future: strictly upper-triangular
            # entries get -inf before the softmax.
            mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0
            a = a.masked_fill(mask, float('-inf'))
        a = a.softmax(dim = 3)
        a = F.dropout(a, self.attention_dropout, self.training)
        y = torch.einsum('nhts,nhsd->nhtd', a, v)
        return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd)
149 ##############################
class MyGPT(nn.Module):
    """Decoder-only (GPT-style) autoregressive transformer.

    Input: (N, T) token indices. Output: (N, T, vocabulary_size) logits,
    where the logits at position t predict token t+1.

    NOTE(review): several lines of this class were lost in the paste (the
    __init__ header, the Residual wrappers, the activation); the trunk
    structure below is a reconstruction — verify against the original file.
    """

    def __init__(self,
                 vocabulary_size,
                 dim_model, dim_keys, dim_hidden,
                 nb_heads, nb_blocks, dropout = 0.):

        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
            PositionalEncoding(len_max = 1e5),
        )

        trunk_blocks = [ ]

        for _ in range(nb_blocks):
            trunk_blocks += [
                # Pre-norm attention sub-block with a residual connection
                Residual(
                    nn.LayerNorm(dim_model),
                    QKVAttention(
                        dim_in = dim_model,
                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
                        nb_heads = nb_heads,
                        causal = True, attention_dropout = dropout
                    ),
                    nn.Linear(in_features = dim_model, out_features = dim_model),
                ),
                # Pre-norm feed-forward sub-block with a residual connection
                Residual(
                    nn.LayerNorm(dim_model),
                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.Sequential(*trunk_blocks)

        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.trunk(x)
        x = self.readout(x)
        return x
200 ######################################################################
class Task:
    """Abstract interface shared by the data set wrappers below.

    Subclasses provide an iterator over training/test batches, the size of
    their token vocabulary, and a routine that synthesizes sample outputs
    from a trained model. (The `class` header line was lost in the paste.)
    """

    def batches(self, split = 'train'):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        pass
212 ######################################################################
class TaskPicoCLVR(Task):
    """Synthetic picoclvr scene-description data set.

    Each sample is a space-separated token sequence describing colored
    squares on a height x width grid; sequences are right-padded with
    '<unk>' to a common length. First 1/5 of the data is the test split.

    NOTE(review): several lines were lost in the paste (self.device, the
    token-set construction, the if/else in batches, the greedy branch of
    the sampler); they are reconstructed below — verify against the
    original file. Assumes a module-level `import picoclvr` exists in a
    part of the file not visible here.
    """

    def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')):
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size > 0 else 250000

        descr = picoclvr.generate(nb, height = height, width = width)
        descr = [ s.strip().split(' ') for s in descr ]
        l = max([ len(s) for s in descr ])
        # Pad every sequence to the longest one so they stack into a tensor
        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]

        tokens = set()
        for s in descr:
            for t in s: tokens.add(t)
        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])

        t = [ [ self.token2id[u] for u in s ] for s in descr ]
        data_input = torch.tensor(t, device = self.device)
        self.test_input = data_input[:nb // 5]
        self.train_input = data_input[nb // 5:]

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        if split == 'train':
            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'):
                yield batch
        else:
            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'):
                yield batch

    def vocabulary_size(self):
        return len(self.token2id)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        img = [ ]
        nb_per_primer = 8

        for primer in [
            'red above green <sep> green top <sep> blue right of red <img>',
            'there is red <sep> there is yellow <sep> there is blue <img>',
            'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
            'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
        ]:

            for k in range(nb_per_primer):
                t_primer = primer.strip().split(' ')
                t_generated = [ ]

                # Autoregressive synthesis: re-run the model on the full
                # prefix and extend it one token at a time.
                for j in range(nb_tokens):
                    t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
                    input = torch.tensor(t, device = self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.id2token[t.item()])

                descr = [ ' '.join(t_primer + t_generated) ]
                # presumably descr2img returns an image tensor batch — confirm
                img += [ picoclvr.descr2img(descr) ]

        img = torch.cat(img, 0)
        file_name = f'result_picoclvr_{n_epoch:04d}.png'
        torchvision.utils.save_image(img / 255.,
                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
        log_string(f'wrote {file_name}')
288 ######################################################################
class TaskWiki103(Task):
    """WikiText-103 language modelling via torchtext.

    Lines are tokenized with the basic_english tokenizer, kept when their
    length is in [len_min, len_max], and padded with '<non>' within each
    batch. The vocabulary keeps tokens seen at least min_freq times;
    everything else maps to '<unk>'.

    NOTE(review): a few lines were lost in the paste (self.device, the
    yield_tokens header, the batch accumulator in yield_batches); they are
    reconstructed below — verify against the original file.
    """

    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
                 device = torch.device('cpu')):

        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')

        # Mostly for debug: cap the number of lines used to build the vocab
        if args.data_size > 0:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(),
            specials = [ '<unk>', '<non>' ],
            min_freq = self.min_freq
        )

        self.vocab.set_default_index(self.vocab[ '<unk>' ])

    def tensorize(self, s):
        """Pad the token lists in s with '<non>' to equal length and map them to id tensors."""
        a = max(len(x) for x in s)
        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])

    def yield_batches(self, ds):
        """Group the acceptable-length lines of ds into batch tensors."""
        s = [ ]
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [ q ]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = [ ]

        # Flush the final, possibly smaller, batch
        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split = 'train'):
        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')

        # Mostly for debug
        if args.data_size > 0:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch'))

    def vocabulary_size(self):
        return len(self.vocab)

    def produce_results(self, n_epoch, model, nb_tokens = 50):
        file_name = f'result_wiki103_{n_epoch:04d}.txt'

        with open(file_name, 'w') as outfile:
            for primer in [
                'the cat is hunting a',
                'paris is the capital',
                'cars are convenient',
                'the difference between men and women is',
                'the object was blue all over and green all over it was',
                'cherries are red and lemons are',
                'cherries are sweet and lemons are',
                'two plus three equals',
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = [ ]

                # Autoregressive synthesis, one token per model call
                for j in range(nb_tokens):

                    input = self.tensorize([ t_primer + t_generated ]).to(self.device)
                    output = model(input)
                    logits = output[0, -1]
                    if args.synthesis_sampling:
                        dist = torch.distributions.categorical.Categorical(logits = logits)
                        t = dist.sample()
                    else:
                        t = logits.argmax()
                    t_generated.append(self.vocab.lookup_token(t))
                    # '<non>' is the padding token: treat it as end-of-sequence
                    if t_generated[-1] == '<non>': break

                s = ' '.join(t_generated)

                outfile.write(f'<{primer}> {s}\n')

        log_string(f'wrote {file_name}')
386 ######################################################################
class TaskMNIST(Task):
    """Autoregressive pixel modelling of MNIST.

    Each 28x28 image is flattened into a sequence of 784 integer pixel
    intensities in [0, 255], hence the 256-token vocabulary.

    NOTE(review): a few lines were lost in the paste (self.device, the
    download keyword, the vocabulary_size return, the greedy branch and the
    pixel write-back in produce_results); they are reconstructed below —
    verify against the original file.
    """

    def __init__(self, batch_size, device = torch.device('cpu')):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split = 'train'):
        assert split in { 'train', 'test' }
        data_set = torchvision.datasets.MNIST(
            root = './data', train = (split == 'train'),
            download = args.download
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size >= 0:
            data_input = data_input[:args.data_size]
        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'):
            yield batch

    def vocabulary_size(self):
        # One token per possible 8-bit pixel intensity
        return 256

    def produce_results(self, n_epoch, model, nb_samples = 64):
        # Start from all-zero sequences and fill pixels left to right:
        # writing into `input` (a view of `results`) updates `results`.
        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
        for input in results.split(self.batch_size):
            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
                output = model(input)
                logits = output[:, s]
                if args.synthesis_sampling:
                    dist = torch.distributions.categorical.Categorical(logits = logits)
                    t = dist.sample()
                else:
                    t = logits.argmax(1)
                input[:, s + 1] = t

        image_name = f'result_mnist_{n_epoch:04d}.png'
        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                     image_name, nrow = 16, pad_value = 0.8)
        log_string(f'wrote {image_name}')
427 ######################################################################
def check_causality(model):
    """Debug helper: measure which input positions each output depends on.

    Returns a (T_out, T_in) matrix a where a[k, t] > 0 iff output position k
    has non-zero gradient w.r.t. input position t; for a causal model the
    matrix should be lower-triangular. (The `output = model(input)` line was
    lost in the paste and is restored here.)
    """
    # NOTE(review): `dim_model` is a module-level name not defined in this
    # chunk — presumably args.dim_model; confirm before relying on it.
    input = torch.rand(1, 5, dim_model).requires_grad_()
    output = model(input)
    a = torch.zeros(output.size(1), input.size(1))
    for k in range(output.size(1)):
        for d in range(output.size(2)):
            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
            a[k] += g.squeeze(0).pow(2).sum(1)
    return a
440 ######################################################################
log_string(f'device {device}')

# Instantiate the data set wrapper selected on the command line.
# (The `else:` header before the raise was lost in the paste.)
if args.data == 'wiki103':
    task = TaskWiki103(batch_size = args.batch_size, device = device)
elif args.data == 'mnist':
    task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
    task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
else:
    raise ValueError(f'Unknown dataset {args.data}.')

vocabulary_size = task.vocabulary_size()

log_string(f'vocabulary_size {vocabulary_size}')
457 ##############################
# Build the model from the command-line hyper-parameters and move it to the
# compute device. (The `model = MyGPT(` header, closing parenthesis and
# `model.to(device)` were lost in the paste.)
model = MyGPT(
    vocabulary_size = vocabulary_size,
    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
470 ######################################################################
# Select the optimizer named on the command line.
# (The `else:` header before the raise was lost in the paste.)
if args.optim == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
elif args.optim == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
else:
    raise ValueError(f'Unknown optimizer {args.optim}.')
for k in range(args.nb_epochs):

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split = 'train'):
        input = input.to(device)
        output = model(input)
        # Next-token prediction: logits at position t are matched against
        # the token at position t+1 (cross_entropy wants (N, C, T)).
        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        # The backward/step lines were lost in the paste; without them the
        # model never updates.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():

        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split = 'test'):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # Clamp the exponent so math.exp cannot overflow early in training
        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))

        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')

        task.produce_results(k, model)
518 ######################################################################