#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, sys, argparse, time, tqdm, itertools

import torch, torchtext, torchvision

from torch.nn import functional as F

# mygpt is assumed to be the companion single-file GPT implementation,
# providing the MyGPT class instantiated below
import mygpt

######################################################################

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
parser = argparse.ArgumentParser(description="My own GPT.")

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument("--seed", type=int, default=0)

parser.add_argument("--nb_epochs", type=int, default=None)

parser.add_argument("--batch_size", type=int, default=25)

parser.add_argument("--data", type=str, default="wiki103")

parser.add_argument("--data_size", type=int, default=None)

parser.add_argument("--optim", type=str, default="adam")

parser.add_argument("--learning_rate", type=float, default=1e-3)

parser.add_argument("--learning_rate_end", type=float, default=1e-6)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--dim_hidden", type=int, default=None)

parser.add_argument("--nb_heads", type=int, default=None)

parser.add_argument("--nb_blocks", type=int, default=None)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
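
# Options specific to the picoclvr task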

parser.add_argument("--picoclvr_nb_colors", type=int, default=5)

parser.add_argument("--picoclvr_height", type=int, default=12)

parser.add_argument("--picoclvr_width", type=int, default=16)

######################################################################

args = parser.parse_args()

log_file = open(args.log_filename, "w")

torch.manual_seed(args.seed)

######################################################################


def log_string(s):
    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())

    if log_file is not None:
        log_file.write(t + s + "\n")
        log_file.flush()

    print(t + s)
    sys.stdout.flush()


for n in vars(args):
    log_string(f"args.{n} {getattr(args, n)}")

######################################################################

# Per-dataset defaults for the hyper-parameters left at None above (NB:
# the values below are plausible settings, not canonical ones)

default_args = {
    "mnist": {
        "nb_epochs": 10,
        "dim_model": 64,
        "dim_keys": 64,
        "dim_hidden": 128,
        "nb_heads": 4,
        "nb_blocks": 6,
    },
    "mnist-debug": {
        "nb_epochs": 2,
        "data_size": 10000,
        "dim_model": 8,
        "dim_keys": 8,
        "dim_hidden": 8,
        "nb_heads": 2,
        "nb_blocks": 4,
    },
    "wiki103": {
        "nb_epochs": 25,
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
    "picoclvr": {
        "nb_epochs": 25,
        "dim_model": 512,
        "dim_keys": 64,
        "dim_hidden": 2048,
        "nb_heads": 8,
        "nb_blocks": 12,
    },
}

if args.data in default_args:
    for k, v in default_args[args.data].items():
        if getattr(args, k) is None:
            setattr(args, k, v)

######################################################################
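
# Sample new sequences from the model, token by token: at each position
# the full (partially generated) batch is fed through the model, and the
# next token is either the argmax of the logits or a sample from the
# corresponding categorical distribution, depending on
# --deterministic_synthesis.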
def autoregression(
    model,
    batch_size,
    nb_samples,
    nb_tokens_to_generate,
    primer=None,
    device=torch.device("cpu"),
):
    results = torch.zeros(
        nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device
    )

    if primer is None:
        first = 0
    else:
        first = primer.size(1)
        results = torch.cat((primer, results), 1)

    for input in results.split(batch_size):
        for s in range(first, input.size(1)):
            output = model(input)
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            input[:, s] = t_next

    return results


######################################################################
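

# A Task bundles a dataset: batches for the train and test splits, the
# vocabulary size, and a procedure that exercises a trained model.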
class Task:
    def batches(self, split="train"):
        pass

    def vocabulary_size(self):
        pass

    def produce_results(self, n_epoch, model):
        pass


######################################################################

import picoclvr


class TaskPicoCLVR(Task):
    # Make a tensor from a list of strings
    def tensorize(self, descr):
        token_descr = [s.strip().split(" ") for s in descr]
        l = max([len(s) for s in token_descr])
        padded_token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
        id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr]
        return torch.tensor(id_descr, device=self.device)
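
    # Remove the all-<nul> columns on the left and the right of the
    # batch, i.e. keep only the span of columns where at least one
    # sequence has a real token.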
    def trim(self, x, token="<nul>"):
        n = self.token2id[token]
        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
        return x[:, a:b]

    def __init__(
        self, batch_size, height, width, nb_colors=5, device=torch.device("cpu")
    ):
        def generate_descr(nb):
            return picoclvr.generate(
                nb, height=self.height, width=self.width, nb_colors=nb_colors
            )

        self.height = height
        self.width = width
        self.batch_size = batch_size
        self.device = device
        nb = args.data_size if args.data_size is not None else 250000

        log_string(f"generating {nb} samples (can take some time)")
        self.train_descr = generate_descr((nb * 4) // 5)
        self.test_descr = generate_descr((nb * 1) // 5)

        # Build the tokenizer
        tokens = {"<nul>"}
        for d in [self.train_descr, self.test_descr]:
            for s in d:
                for t in s.strip().split(" "):
                    tokens.add(t)
        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])

        # Tokenize the train and test sets
        self.train_input = self.tensorize(self.train_descr)
        self.test_input = self.tensorize(self.test_descr)

    def batches(self, split="train"):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"):
            yield self.trim(batch)

    def vocabulary_size(self):
        return len(self.token2id)
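
    # Generate nb_per_primer scenes for each primer description, then
    # count, with the picoclvr helpers, how many of the requested
    # properties each generated scene actually misses.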
    def test_model(
        self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
    ):
        nb_tokens_to_generate = self.height * self.width + 3
        result_descr = []

        for primer_descr in primers_descr:
            results = autoregression(
                model,
                self.batch_size,
                nb_samples=nb_per_primer,
                nb_tokens_to_generate=nb_tokens_to_generate,
                primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
                device=self.device,
            )

            l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
            result_descr += l

        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

        nb_requested_properties, _, nb_missing_properties = zip(*np)

        log_string(
            f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
        )

        # count[i, j] is the number of results with i requested and j
        # missing properties
        np = torch.tensor(np)
        count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
        for i in range(count.size(0)):
            for j in range(count.size(1)):
                count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()

        if generate_images:
            img = [
                picoclvr.descr2img(d, height=self.height, width=self.width)
                for d in result_descr
            ]

            img = torch.cat(img, 0)
            image_name = f"result_picoclvr_{n_epoch:04d}.png"
            torchvision.utils.save_image(
                img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
            )
            log_string(f"wrote {image_name}")

        return count

    def produce_results(self, n_epoch, model):
        primers_descr = [
            "red above green <sep> green top <sep> blue right of red <img>",
            "there is red <sep> there is yellow <sep> there is blue <img>",
            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
        ]

        self.test_model(
            n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
        )

        # Evaluation on all the test descriptions, disabled because it
        # is far too slow:

        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]

        # count=self.test_model(
        #     n_epoch, model,
        #     test_primers_descr,
        #     nb_per_primer=1, generate_images=False
        # )

        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
        #     for i in range(count.size(0)):
        #         for j in range(count.size(1)):
        #             f.write(f'{count[i,j]}')
        #             f.write(" " if j<count.size(1)-1 else "\n")


######################################################################
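

# Word-level language modeling on WikiText-103, with torchtext doing
# the tokenization and the vocabulary construction. Lines are batched
# by padding them with <nul> to the longest one.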
class TaskWiki103(Task):
    def __init__(
        self,
        batch_size,
        len_min=10,
        len_max=200,
        min_freq=100,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.len_min = len_min
        self.len_max = len_max
        self.min_freq = min_freq
        self.device = device

        self.tokenizer = torchtext.data.get_tokenizer("basic_english")
        train_iter = torchtext.datasets.WikiText103(split="train", root="./data/nlp/")

        # Mostly for debugging
        if args.data_size is not None:
            train_iter = itertools.islice(train_iter, args.data_size)

        def yield_tokens():
            for l in tqdm.tqdm(train_iter, desc="vocab"):
                yield self.tokenizer(l)

        self.vocab = torchtext.vocab.build_vocab_from_iterator(
            yield_tokens(), specials=["<unk>", "<nul>"], min_freq=self.min_freq
        )

        # Out-of-vocabulary words map to <unk>
        self.vocab.set_default_index(self.vocab["<unk>"])

    # Makes a tensor from a list of lists of tokens
    def tensorize(self, s):
        a = max(len(x) for x in s)
        return torch.tensor([self.vocab(x + ["<nul>"] * (a - len(x))) for x in s])

    def yield_batches(self, ds):
        s = []
        for l in ds:
            q = self.tokenizer(l)
            if len(q) >= self.len_min and len(q) <= self.len_max:
                s += [q]
                if len(s) == self.batch_size:
                    yield self.tensorize(s)
                    s = []

        if len(s) > 0:
            yield self.tensorize(s)

    def batches(self, split="train"):
        data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/")

        # Mostly for debugging
        if args.data_size is not None:
            data_iter = itertools.islice(data_iter, args.data_size)

        return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}"))

    def vocabulary_size(self):
        return len(self.vocab)
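
    # Sampling: the sequence grows one token at a time; at each step the
    # primer plus the tokens generated so far are fed back in, with one
    # extra padded position whose logits define the next-token
    # distribution.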
    def produce_results(self, n_epoch, model):
        nb_tokens = 50  # number of tokens generated per primer
        file_name = f"result_wiki103_{n_epoch:04d}.txt"

        with open(file_name, "w") as outfile:
            for primer in [
                "the cat is hunting a",
                "paris is the capital",
                "cars are convenient",
                "the difference between men and women is",
                "the object was blue all over and green all over it was",
                "cherries are red and lemons are",
                "cherries are sweet and lemons are",
                "two plus three equals",
            ]:
                t_primer = self.tokenizer(primer)
                t_generated = []

                for j in range(nb_tokens):
                    input = self.tensorize([t_primer + t_generated]).to(self.device)
                    input = F.pad(
                        input, (0, 1)
                    )  # Add the next token, the one to predict
                    output = model(input)
                    logits = output[0, -1]
                    if args.deterministic_synthesis:
                        t_next = logits.argmax()
                    else:
                        dist = torch.distributions.categorical.Categorical(
                            logits=logits
                        )
                        t_next = dist.sample()
                    t_generated.append(self.vocab.lookup_token(t_next))
                    if t_generated[-1] == "<nul>":
                        break

                s = " ".join(t_generated)

                outfile.write(f"<{primer}> {s}\n")

        log_string(f"wrote {file_name}")


######################################################################
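

# MNIST as a sequence task: each image is flattened into 784 tokens,
# one per pixel, with the 256 gray levels as the vocabulary; sampling
# generates images pixel by pixel.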
class TaskMNIST(Task):
    def __init__(self, batch_size, device=torch.device("cpu")):
        self.device = device
        self.batch_size = batch_size

    def batches(self, split="train"):
        assert split in {"train", "test"}
        data_set = torchvision.datasets.MNIST(
            root="./data", train=(split == "train"), download=True
        )
        data_input = data_set.data.view(-1, 28 * 28).long()
        if args.data_size is not None:
            data_input = data_input[: args.data_size]
        for batch in tqdm.tqdm(
            data_input.split(self.batch_size), desc=f"epoch-{split}"
        ):
            yield batch

    def vocabulary_size(self):
        return 256

    def produce_results(self, n_epoch, model):
        nb_samples = 64
        results = autoregression(
            model, self.batch_size, nb_samples, 28 * 28, device=self.device
        )
        image_name = f"result_mnist_{n_epoch:04d}.png"
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        log_string(f"wrote {image_name}")


######################################################################

log_string(f"device {device}")

if args.data == "wiki103":
    task = TaskWiki103(batch_size=args.batch_size, device=device)
elif args.data in {"mnist", "mnist-debug"}:
    task = TaskMNIST(batch_size=args.batch_size, device=device)
elif args.data == "picoclvr":
    task = TaskPicoCLVR(
        batch_size=args.batch_size,
        height=args.picoclvr_height,
        width=args.picoclvr_width,
        nb_colors=args.picoclvr_nb_colors,
        device=device,
    )
else:
    raise ValueError(f"Unknown dataset {args.data}.")

vocabulary_size = task.vocabulary_size()

log_string(f"vocabulary_size {vocabulary_size}")

##############################
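
# Build the model; mygpt is assumed to expose a MyGPT class taking
# exactly these keyword arguments.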
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    dropout=args.dropout,
)

model.to(device)

nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")

######################################################################
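
# Resume an interrupted run: the checkpoint stores the number of epochs
# already finished, the model weights, and the RNG states, so training
# continues where it left off.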
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string("not trying to load checkpoint.")

else:
    try:
        checkpoint = torch.load(args.checkpoint_name)
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception:
        log_string("error when loading the checkpoint.")
        exit(1)

######################################################################
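
# Compute the perplexity of the train set under the empirical token
# distribution, i.e. the best a model that ignores context could do;
# this serves as a reference for the prediction perplexities below.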
token_count = 0
for input in task.batches(split="train"):
    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
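
# Geometric learning-rate schedule: interpolate linearly in log space
# between --learning_rate and --learning_rate_end over the epochs,
# lr(u) = exp((1 - u) * log(lr_start) + u * log(lr_end)) with
# u = n_epoch / (nb_epochs - 1); a negative --learning_rate_end keeps
# the learning rate constant.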
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    if args.learning_rate_end < 0:
        lr = args.learning_rate
    else:
        u = n_epoch / (args.nb_epochs - 1)
        lr = math.exp(
            (1 - u) * math.log(args.learning_rate)
            + u * math.log(args.learning_rate_end)
        )
    log_string(f"learning_rate {lr}")

    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(f"Unknown optimizer {args.optim}.")
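
    # The model outputs logits of shape (N, T, vocabulary_size);
    # F.cross_entropy expects the class dimension second, hence the
    # transpose. The target is the input itself, assuming the model
    # shifts its input internally so that output[:, s] is a prediction
    # of input[:, s] from the tokens before it.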
    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        output = model(input)
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.autograd.no_grad():
        model.eval()

        nb_test_samples, acc_test_loss = 0, 0.0

        for input in task.batches(split="test"):
            input = input.to(device)
            output = model(input)
            loss = F.cross_entropy(output.transpose(1, 2), input)
            acc_test_loss += loss.item() * input.size(0)
            nb_test_samples += input.size(0)

        # The min(100, ...) clamp avoids an overflow in math.exp while
        # the model is still very bad
        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(
            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
        )

        task.produce_results(n_epoch, model)
607 "nb_epochs_finished": n_epoch + 1,
608 "model_state": model.state_dict(),
609 "rng_state": torch.get_rng_state(),
612 if torch.cuda.is_available():
613 checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
615 torch.save(checkpoint, args.checkpoint_name)

######################################################################