3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
6 # Written by Francois Fleuret <francois@fleuret.org>
8 # torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
11 import math, sys, argparse, time, tqdm, os
13 import torch, torchvision
15 from torch.nn import functional as F
17 import mygpt, tensorstack
19 ######################################################################
21 if torch.cuda.is_available():
22 device = torch.device("cuda")
23 torch.backends.cuda.matmul.allow_tf32 = True
25 device = torch.device("cpu")
27 ######################################################################
29 parser = argparse.ArgumentParser(
30 description="An implementation of GPT with cache.",
31 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
34 parser.add_argument("--task", type=str, default="picoclvr")
36 parser.add_argument("--log_filename", type=str, default="train.log")
38 parser.add_argument("--result_dir", type=str, default="results_default")
40 parser.add_argument("--seed", type=int, default=0)
42 parser.add_argument("--nb_epochs", type=int, default=25)
44 parser.add_argument("--batch_size", type=int, default=None)
46 parser.add_argument("--nb_train_samples", type=int, default=250000)
48 parser.add_argument("--nb_test_samples", type=int, default=10000)
50 parser.add_argument("--optim", type=str, default="adam")
52 parser.add_argument("--learning_rate", type=float, default=1e-4)
54 parser.add_argument("--learning_rate_schedule", type=str, default="10: 2e-5,30: 4e-6")
56 parser.add_argument("--dim_model", type=int, default=512)
58 parser.add_argument("--dim_keys", type=int, default=64)
60 parser.add_argument("--dim_hidden", type=int, default=2048)
62 parser.add_argument("--nb_heads", type=int, default=8)
64 parser.add_argument("--nb_blocks", type=int, default=12)
66 parser.add_argument("--dropout", type=float, default=0.1)
68 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
70 parser.add_argument("--no_checkpoint", action="store_true", default=False)
72 parser.add_argument("--overwrite_results", action="store_true", default=False)
74 parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
76 ##############################
79 parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
81 parser.add_argument("--picoclvr_height", type=int, default=12)
83 parser.add_argument("--picoclvr_width", type=int, default=16)
85 parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
87 ##############################
90 parser.add_argument("--maze_height", type=int, default=13)
92 parser.add_argument("--maze_width", type=int, default=21)
94 parser.add_argument("--maze_nb_walls", type=int, default=15)
96 ##############################
99 parser.add_argument("--snake_height", type=int, default=6)
101 parser.add_argument("--snake_width", type=int, default=8)
103 parser.add_argument("--snake_nb_colors", type=int, default=3)
105 parser.add_argument("--snake_length", type=int, default=400)
107 ######################################################################
109 args = parser.parse_args()
111 assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
114 os.mkdir(args.result_dir)
115 except FileExistsError:
116 if not args.overwrite_results:
117 print(f"result directory {args.result_dir} already exists")
120 log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
123 # torch.backends.cudnn.deterministic = True
124 # torch.backends.cudnn.benchmark = False
125 # torch.use_deterministic_algorithms(True)
126 torch.manual_seed(args.seed)
127 if torch.cuda.is_available():
128 torch.cuda.manual_seed_all(args.seed)
130 ######################################################################
147 if args.task in default_args:
148 for k, v in default_args[args.task].items():
149 if getattr(args, k) is None:
152 ######################################################################
156 t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
158 if log_file is not None:
159 log_file.write(t + s + "\n")
167 log_string(f"args.{n} {getattr(args, n)}")
169 ######################################################################
172 def masked_inplace_autoregression(
173 model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
175 for input, ar_mask in tqdm.tqdm(
176 zip(input.split(batch_size), ar_mask.split(batch_size)),
178 desc="autoregression",
179 total=input.size(0) // batch_size,
181 i = (ar_mask.sum(0) > 0).nonzero()
184 mygpt.BracketedSequence(input, 0, i.min())
185 ) # Needed to initialize the model's cache
186 for s in range(i.min(), i.max() + 1):
187 output = model(mygpt.BracketedSequence(input, s, 1)).x
188 logits = output[:, s]
189 if forbidden_tokens is not None:
190 logits = logits.masked_fill(forbidden_tokens, float("-inf"))
191 if args.deterministic_synthesis:
192 t_next = logits.argmax(1)
194 dist = torch.distributions.categorical.Categorical(logits=logits)
195 t_next = dist.sample()
196 input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
199 ######################################################################
203 def batches(self, split="train"):
206 def vocabulary_size(self):
209 def produce_results(self, n_epoch, model):
213 ######################################################################
218 class TaskPicoCLVR(Task):
219 # Make a tensor from a list of strings
def tensorize(self, descr):
    """Turn a list of description strings into a padded tensor of token ids.

    Each string is stripped and split on single spaces. Sequences shorter
    than the longest one are right-padded with "<nul>". Tokens are mapped
    through self.token2id and the tensor is created on self.device.
    """
    rows = []
    for sentence in descr:
        rows.append(sentence.strip().split(" "))

    # Every row is padded up to the length of the longest description.
    width = max(len(row) for row in rows)

    ids = []
    for row in rows:
        padded = row + ["<nul>"] * (width - len(row))
        ids.append([self.token2id[token] for token in padded])

    return torch.tensor(ids, device=self.device)
227 # Make a list of strings from a tensor
def detensorize(self, x):
    """Turn rows of token ids back into space-separated description strings.

    Each element of each row of x is converted with .item() and looked up
    in self.id2token; the tokens of a row are joined with single spaces.
    """
    sentences = []
    for row in x:
        words = [self.id2token[token_id.item()] for token_id in row]
        sentences.append(" ".join(words))
    return sentences
231 # Trim all the tensors in the tuple z so as to remove as many tokens as
232 # possible from the left and right of the first tensor. If z is a tuple,
233 # all its elements are trimmed according to the trimming of the first tensor.
234 def trim(self, z, token="<nul>"):
235 n = self.token2id[token]
238 i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
239 a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
240 return tuple([t[:, a:b] for t in z])
242 i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
243 a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
246 ######################
247 # Not the cleanest part of the code
249 # Extract the last image of each sequence, from the last <img>
250 # included, and set to <nul> all the tokens from the beginning of
251 # that image to the end
252 def excise_last_image(self, input):
253 t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
254 nb_img_tokens = self.height * self.width + 1
256 input = input.clone()
257 t = (input == t_img).long()
258 tail_masks = (t.cumsum(dim=1) == t.sum(dim=1, keepdim=True)).long()
259 i = (t * tail_masks).nonzero(as_tuple=True)
262 i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
264 images = self.trim(input[j])
266 loss_masks = 1 - tail_masks
267 input, loss_masks = self.trim((input, loss_masks))
268 return input, loss_masks, images
270 def add_true_image(self, input, images, loss_masks):
271 t_nul = self.token2id["<nul>"]
272 nb_img_tokens = self.height * self.width + 1
273 input = F.pad(input, (0, nb_img_tokens), value=t_nul)
274 loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
275 t = (input == t_nul).long()
276 i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
279 i[1][:, None] + torch.arange(nb_img_tokens, device=input.device)[None, :],
283 input, loss_masks = self.trim((input, loss_masks))
284 return input, loss_masks
286 def add_generated_image(self, input, loss_masks, model):
287 t_img, t_nul = self.token2id["<img>"], self.token2id["<nul>"]
288 nb_img_tokens = self.height * self.width + 1
290 input = F.pad(input, (0, nb_img_tokens), value=t_nul)
291 loss_masks = F.pad(loss_masks, (0, nb_img_tokens), value=0)
292 t = (input == t_nul).long()
293 i = (t.cumsum(dim=1) == 1).nonzero(as_tuple=True)
300 + torch.arange(nb_img_tokens - 1, device=input.device)[None, :],
302 ar_masks = input.new_zeros(input.size(), dtype=torch.int64)
305 torch.arange(self.vocabulary_size(), device=input.device) == t_nul
307 with torch.autograd.no_grad():
310 masked_inplace_autoregression(
320 input, loss_masks = self.trim((input, loss_masks))
322 return input, loss_masks
324 ######################
334 device=torch.device("cpu"),
338 def generate_descr(nb, cache_suffix, pruner):
339 return picoclvr.generate(
349 self.batch_size = batch_size
351 self.pruner_train = pruner_train
352 self.pruner_eval = pruner_eval
355 "nb_train_samples": nb_train_samples,
356 "nb_test_samples": nb_test_samples,
359 "nb_colors": nb_colors,
360 "batch_size": batch_size,
361 "rng_state": list(torch.get_rng_state()),
365 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
367 self.train_descr = generate_descr(
368 nb_train_samples, "train", pruner=self.pruner_train
370 self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
372 # Build the tokenizer
373 tokens = {"<nul>", "<img>"}
374 for d in [self.train_descr, self.test_descr]:
376 for t in s.strip().split(" "):
378 # make this set a sorted list to get the same tensors given
380 tokens = list(tokens)
382 self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
383 self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
385 # Tokenize the train and test sets
386 self.train_input = self.tensorize(self.train_descr)
387 self.test_input = self.tensorize(self.test_descr)
389 def batches(self, split="train"):
390 assert split in {"train", "test"}
391 input = self.train_input if split == "train" else self.test_input
392 for batch in tqdm.tqdm(
393 input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
395 yield self.trim(batch)
def vocabulary_size(self):
    """Return the number of entries in the token-to-id mapping."""
    vocabulary = self.token2id
    return len(vocabulary)
400 def compute_missing_properties(self, n_epoch, model, pruner=None):
401 acc_nb_requested_properties = []
402 acc_nb_missing_properties = []
405 for input in tqdm.tqdm(
406 self.test_input.split(self.batch_size),
408 desc=f"test-properties",
410 tape, loss_masks, _ = self.excise_last_image(input)
411 tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
412 result_descr = self.detensorize(tape)
413 np = picoclvr.nb_properties(
419 nb_requested_properties, _, nb_missing_properties = zip(*np)
420 acc_nb_requested_properties += nb_requested_properties
421 acc_nb_missing_properties += nb_missing_properties
422 acc_nb_results += len(result_descr)
424 nb_requested_properties = sum(acc_nb_requested_properties)
425 nb_missing_properties = sum(acc_nb_missing_properties)
427 prefix = "" if pruner is None else "pruned_"
428 log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
430 f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
433 f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
436 ######################################################################
438 def produce_results(self, n_epoch, model):
439 self.compute_missing_properties(n_epoch, model)
441 if self.pruner_eval is not None:
442 self.compute_missing_properties(n_epoch, model, self.pruner_eval)
444 nb_tokens_to_generate = self.height * self.width + 3
449 for primer_descr in [
450 "red above green <sep> green top <sep> blue right of red",
451 "there is red <sep> there is yellow <sep> there is blue",
452 "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
453 "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
455 primer += [primer_descr] * nb_per_primer
457 tape = self.tensorize(primer)
458 loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
459 tape, loss_masks = self.add_generated_image(tape, loss_masks, model)
460 result_descr = self.detensorize(tape)
462 np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
464 acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
465 acc_nb_results = len(result_descr)
467 nb_requested_properties = sum(acc_nb_requested_properties)
468 nb_missing_properties = sum(acc_nb_missing_properties)
471 log_string(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
473 f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
476 f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
479 img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
483 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
487 torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
493 image_name = os.path.join(args.result_dir, f"picoclvr_result_{n_epoch:04d}.png")
494 torchvision.utils.save_image(
495 img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=1.0
497 log_string(f"wrote {image_name}")
500 ######################################################################
503 class TaskMNIST(Task):
504 def __init__(self, batch_size, device=torch.device("cpu")):
506 self.batch_size = batch_size
508 def batches(self, split="train"):
509 assert split in {"train", "test"}
510 data_set = torchvision.datasets.MNIST(
511 root="./data", train=(split == "train"), download=True
513 data_input = data_set.data.view(-1, 28 * 28).long()
514 if args.nb_train_samples is not None:
515 data_input = data_input[: args.nb_train_samples]
516 for batch in tqdm.tqdm(
517 data_input.split(self.batch_size), desc=f"epoch-{split}"
521 def vocabulary_size(self):
524 def produce_results(self, n_epoch, model):
525 results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
526 ar_mask = torch.full_like(results, 1)
527 masked_inplace_autoregression(
528 model, self.batch_size, results, ar_mask, device=self.device
530 image_name = os.path.join(args.result_dir, f"mnist_result_{n_epoch:04d}.png")
531 torchvision.utils.save_image(
532 1 - results.reshape(-1, 1, 28, 28) / 255.0,
537 log_string(f"wrote {image_name}")
540 ######################################################################
545 class TaskMaze(Task):
def map2seq(self, *m):
    """Flatten each tensor of m from dim 1 onward and concatenate along dim 1."""
    flattened = []
    for grid in m:
        flattened.append(grid.flatten(1))
    return torch.cat(flattened, 1)
def seq2map(self, s):
    """Split flat sequences back into maps of size self.height x self.width.

    s is reshaped to (s.size(0), -1, self.height, self.width); a generator
    yielding the slice [:, k] for each index k of dim 1 is returned.
    """
    batch = s.size(0)
    stacked = s.reshape(batch, -1, self.height, self.width)
    nb_maps = stacked.size(1)
    return (stacked[:, index] for index in range(nb_maps))
561 device=torch.device("cpu"),
563 self.batch_size = batch_size
568 train_mazes, train_paths, _ = maze.create_maze_data(
573 progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
575 self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
577 test_mazes, test_paths, _ = maze.create_maze_data(
582 progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
584 self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
586 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
588 def batches(self, split="train", nb_to_use=-1, desc=None):
589 assert split in {"train", "test"}
590 input = self.train_input if split == "train" else self.test_input
592 input = input[:nb_to_use]
594 desc = f"epoch-{split}"
595 for batch in tqdm.tqdm(
596 input.split(self.batch_size), dynamic_ncols=True, desc=desc
600 def vocabulary_size(self):
def compute_error(self, model, split="train", nb_to_use=-1):
    """Count how many paths regenerated by the model are correct.

    For every batch of the requested split, the part of each sequence
    after the first self.height * self.width tokens is zeroed out and
    regenerated with masked_inplace_autoregression. The result is split
    back into maps and checked with maze.path_correctness.

    Args:
        model: model passed through to masked_inplace_autoregression.
        split: which split to evaluate, forwarded to self.batches.
        nb_to_use: forwarded to self.batches as its second argument.

    Returns:
        A pair (nb_total, nb_correct). nb_correct is accumulated from
        tensor sums, so it is a tensor as soon as one batch was seen.
    """
    nb_total, nb_correct = 0, 0
    # Iterate over this task's own batches rather than the module-level
    # `task` global, so the method works on any instance.
    for input in self.batches(split, nb_to_use):
        result = input.clone()
        # Mask set to 1 on the positions after the first height * width
        # tokens: those are the ones the model has to regenerate.
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        # Erase the masked positions before regenerating them in place.
        result *= 1 - ar_mask
        masked_inplace_autoregression(
            model, self.batch_size, result, ar_mask, device=self.device
        )
        mazes, paths = self.seq2map(result)
        nb_correct += maze.path_correctness(mazes, paths).long().sum()
        nb_total += mazes.size(0)

    return nb_total, nb_correct
619 def produce_results(self, n_epoch, model):
620 with torch.autograd.no_grad():
624 train_nb_total, train_nb_correct = self.compute_error(
625 model, "train", nb_to_use=1000
628 f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
631 test_nb_total, test_nb_correct = self.compute_error(
632 model, "test", nb_to_use=1000
635 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
638 input = self.test_input[:48]
639 result = input.clone()
640 ar_mask = result.new_zeros(result.size())
641 ar_mask[:, self.height * self.width :] = 1
642 result *= 1 - ar_mask
643 masked_inplace_autoregression(
644 model, self.batch_size, result, ar_mask, device=self.device
647 mazes, paths = self.seq2map(input)
648 _, predicted_paths = self.seq2map(result)
650 filename = os.path.join(args.result_dir, f"maze_result_{n_epoch:04d}.png")
655 predicted_paths=predicted_paths,
656 path_correct=maze.path_correctness(mazes, predicted_paths),
658 log_string(f"wrote {filename}")
663 ######################################################################
666 def generate_snake_sequences(
667 nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
669 worlds = torch.randint(nb_colors, (nb, height, width), device=device)
670 nb_prior_visits = torch.zeros(nb, height, width, device=device)
673 snake_position = torch.cat(
675 torch.randint(height, (nb, 1), device=device),
676 torch.randint(width, (nb, 1), device=device),
680 snake_direction = torch.randint(4, (nb,), device=device)
681 sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
682 sequences_prior_visits = torch.zeros(
683 nb, 2 * length, device=device, dtype=torch.int64
685 i = torch.arange(nb, device=device) # [:,None]
687 for l in range(length):
689 snake_next_direction = torch.cat(
691 (snake_direction[:, None] - 1) % 4,
692 snake_direction[:, None],
693 (snake_direction[:, None] + 1) % 4,
699 vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
700 vw = snake_next_direction % 2 * (snake_next_direction - 2)
703 snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
704 snake_next_position = snake_position[:, None, :] + snake_next_speed
707 val = torch.logical_and(
709 snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
712 snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
716 # The multiplicative factors bias toward moving forward
719 * torch.tensor([[1.0, 2.0, 1.0]], device=device)
724 snake_direction = snake_next_direction[i, j]
726 sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
727 sequences_prior_visits[:, 2 * l] = nb_prior_visits[
728 i, snake_position[:, 0], snake_position[:, 1]
730 if l < prompt_length:
731 nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
732 sequences[:, 2 * l + 1] = snake_direction
735 snake_position = snake_next_position[i, j]
737 return sequences, sequences_prior_visits
740 # generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
744 def snake_solver(input, ar_mask):
745 for n in range(input.size(0)):
746 i, j, memory = 0, 0, {}
749 for l in range(input.size(1) // 2):
750 if ar_mask[n, 2 * l] == 1:
751 if memory.get((i, j)) is None:
754 input[n, 2 * l] = memory[(i, j)]
756 # print(f'@3 {memory=}')
757 if memory.get((i, j)) is None:
758 memory[(i, j)] = input[n, 2 * l]
760 assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
761 # print(f'@1 {i=} {j=}')
762 d = input[n, 2 * l + 1].item()
763 i += (d + 1) % 2 * (d - 1)
765 # print(f'@2 {i=} {j=}')
768 class TaskSnake(Task):
779 device=torch.device("cpu"),
781 self.batch_size = batch_size
785 self.prompt_length = prompt_length
787 self.train_input, self.train_prior_visits = generate_snake_sequences(
796 self.test_input, self.test_prior_visits = generate_snake_sequences(
806 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
808 def batches(self, split="train", nb_to_use=-1, desc=None):
809 assert split in {"train", "test"}
810 input = self.train_input if split == "train" else self.test_input
812 input = input[:nb_to_use]
814 desc = f"epoch-{split}"
815 for batch in tqdm.tqdm(
816 input.split(self.batch_size), dynamic_ncols=True, desc=desc
820 def vocabulary_size(self):
823 def produce_results(self, n_epoch, model):
824 with torch.autograd.no_grad():
828 def compute_nb_correct(input, prior_visits):
829 result = input.clone()
830 i = torch.arange(result.size(1), device=result.device)[None, :]
832 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
836 result *= 1 - ar_mask
838 # snake_solver(result,ar_mask)
840 masked_inplace_autoregression(
841 model, self.batch_size, result, ar_mask, device=self.device
844 nb_total = ((prior_visits > 0) * ar_mask).sum()
847 (result == input).long() * (prior_visits > 0) * ar_mask
850 # nb_total = result.size(0)
851 # nb_correct = ((result - input).abs().sum(1) == 0).sum()
853 return nb_total, nb_correct
855 # train_nb_total, train_nb_correct = compute_nb_correct(
856 # self.train_input, self.train_prior_visits
860 # f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
863 test_nb_total, test_nb_correct = compute_nb_correct(
864 self.test_input[:1000], self.test_prior_visits[:1000]
868 f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
874 ######################################################################
def picoclvr_pruner_horizontal_green(p):
    """Return False when p contains "green" together with "left" or "right".

    Returns True in every other case.
    """
    if "green" not in p:
        return True
    return "left" not in p and "right" not in p
881 picoclvr_pruner_train = (
882 picoclvr_pruner_horizontal_green
883 if args.picocvlr_prune_properties in {"train+eval"}
887 picoclvr_pruner_eval = (
888 (lambda p: not picoclvr_pruner_horizontal_green(p))
889 if args.picocvlr_prune_properties in {"train+eval", "eval"}
893 ######################################################################
895 if args.task == "picoclvr":
897 nb_train_samples=args.nb_train_samples,
898 nb_test_samples=args.nb_test_samples,
899 batch_size=args.batch_size,
900 height=args.picoclvr_height,
901 width=args.picoclvr_width,
902 nb_colors=args.picoclvr_nb_colors,
904 pruner_train=picoclvr_pruner_train,
905 pruner_eval=picoclvr_pruner_eval,
908 elif args.task == "mnist":
910 batch_size=args.batch_size,
914 elif args.task == "maze":
916 nb_train_samples=args.nb_train_samples,
917 nb_test_samples=args.nb_test_samples,
918 batch_size=args.batch_size,
919 height=args.maze_height,
920 width=args.maze_width,
921 nb_walls=args.maze_nb_walls,
925 elif args.task == "snake":
927 nb_train_samples=args.nb_train_samples,
928 nb_test_samples=args.nb_test_samples,
929 batch_size=args.batch_size,
930 height=args.snake_height,
931 width=args.snake_width,
932 nb_colors=args.snake_nb_colors,
933 length=args.snake_length,
934 prompt_length=args.snake_length // 2,
939 raise ValueError(f"Unknown task {args.task}")
941 ######################################################################
943 log_string(f"device {device}")
945 vocabulary_size = task.vocabulary_size()
947 log_string(f"vocabulary_size {vocabulary_size}")
949 ##############################
952 vocabulary_size=vocabulary_size,
953 dim_model=args.dim_model,
954 dim_keys=args.dim_keys,
955 dim_hidden=args.dim_hidden,
956 nb_heads=args.nb_heads,
957 nb_blocks=args.nb_blocks,
959 dropout=args.dropout,
964 nb_parameters = sum(p.numel() for p in model.parameters())
965 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
967 ######################################################################
969 nb_epochs_finished = 0
971 if args.no_checkpoint:
972 log_string(f"not trying to load checkpoint.")
976 checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
977 checkpoint = torch.load(checkpoint_name)
978 nb_epochs_finished = checkpoint["nb_epochs_finished"]
979 model.load_state_dict(checkpoint["model_state"])
980 torch.set_rng_state(checkpoint["rng_state"])
981 if torch.cuda.is_available():
982 torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
984 log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
986 except FileNotFoundError:
987 log_string("starting from scratch.")
990 log_string("error when loading the checkpoint.")
993 ######################################################################
995 nb_epochs = args.nb_epochs if args.nb_epochs > 0 else nb_epochs_default
998 for input in task.batches(split="train"):
999 token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
1000 token_probas = token_count / token_count.sum()
1001 entropy = -torch.xlogy(token_probas, token_probas).sum()
1002 train_set_perplexity = math.exp(entropy)
1004 ##############################
1006 if args.learning_rate_schedule == "cos":
1007 learning_rate_schedule = {}
1008 for n_epoch in range(args.nb_epochs):
1009 u = n_epoch / args.nb_epochs * math.pi
1010 learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
1015 tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
1019 learning_rate_schedule = {}
1020 learning_rate = args.learning_rate
1021 for n_epoch in range(args.nb_epochs):
1023 learning_rate = u[n_epoch]
1024 learning_rate_schedule[n_epoch] = learning_rate
1026 log_string(f"learning_rate_schedule {learning_rate_schedule}")
1028 ##############################
1032 if nb_epochs_finished >= nb_epochs:
1033 task.produce_results(nb_epochs_finished, model)
1035 for n_epoch in range(nb_epochs_finished, nb_epochs):
1036 learning_rate = learning_rate_schedule[n_epoch]
1038 log_string(f"learning_rate {learning_rate}")
1040 if args.optim == "sgd":
1041 optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
1042 elif args.optim == "adam":
1043 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
1044 elif args.optim == "adamw":
1045 optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
1047 raise ValueError(f"Unknown optimizer {args.optim}.")
1051 nb_train_samples, acc_train_loss = 0, 0.0
1053 for input in task.batches(split="train"):
1054 input = input.to(device)
1055 output = model(mygpt.BracketedSequence(input)).x
1056 loss = F.cross_entropy(output.transpose(1, 2), input)
1057 acc_train_loss += loss.item() * input.size(0)
1058 nb_train_samples += input.size(0)
1059 nb_samples_seen += input.size(0)
1061 optimizer.zero_grad()
1065 with torch.autograd.no_grad():
1068 nb_test_samples, acc_test_loss = 0, 0.0
1070 for input in task.batches(split="test"):
1071 input = input.to(device)
1073 # input, loss_masks, true_images = task.excise_last_image(input)
1074 # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
1076 output = model(mygpt.BracketedSequence(input)).x
1077 loss = F.cross_entropy(output.transpose(1, 2), input)
1078 acc_test_loss += loss.item() * input.size(0)
1079 nb_test_samples += input.size(0)
1081 train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
1082 test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
1085 f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
1088 task.produce_results(n_epoch, model)
1091 "nb_epochs_finished": n_epoch + 1,
1092 "model_state": model.state_dict(),
1093 "rng_state": torch.get_rng_state(),
1096 if torch.cuda.is_available():
1097 checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
1099 checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
1100 torch.save(checkpoint, checkpoint_name)
1101 log_string(f"saved checkpoint {checkpoint_name}")
1103 ######################################################################