Added default configurations and reformatted with black.
diff --git a/main.py b/main.py
index a18beb1..f7d03cf 100755
--- a/main.py
+++ b/main.py
@@ -15,200 +15,325 @@ import mygpt
 
 ######################################################################
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 ######################################################################
+parser = argparse.ArgumentParser(description="My own GPT.")
 
-parser = argparse.ArgumentParser(description = 'My own GPT.')
+parser.add_argument("--log_filename", type=str, default="train.log")
 
-parser.add_argument('--log_filename',
-                    type = str, default = 'train.log')
+parser.add_argument("--seed", type=int, default=0)
 
-parser.add_argument('--download',
-                    action='store_true', default = False)
+parser.add_argument("--nb_epochs", type=int, default=None)
 
-parser.add_argument('--seed',
-                    type = int, default = 0)
+parser.add_argument("--batch_size", type=int, default=25)
 
-parser.add_argument('--nb_epochs',
-                    type = int, default = 100)
+parser.add_argument("--data", type=str, default="wiki103")
 
-parser.add_argument('--batch_size',
-                    type = int, default = 25)
+parser.add_argument("--data_size", type=int, default=None)
 
-parser.add_argument('--data',
-                    type = str, default = 'wiki103')
+parser.add_argument("--optim", type=str, default="adam")
 
-parser.add_argument('--data_size',
-                    type = int, default = -1)
+parser.add_argument("--learning_rate", type=float, default=1e-3)
 
-parser.add_argument('--optim',
-                    type = str, default = 'adam')
+parser.add_argument("--learning_rate_end", type=float, default=1e-6)
 
-parser.add_argument('--learning_rate',
-                    type = float, default = 1e-4)
+parser.add_argument("--dim_model", type=int, default=None)
 
-parser.add_argument('--dim_model',
-                    type = int, default = 512)
+parser.add_argument("--dim_keys", type=int, default=None)
 
-parser.add_argument('--dim_keys',
-                    type = int, default = 64)
+parser.add_argument("--dim_hidden", type=int, default=None)
 
-parser.add_argument('--dim_hidden',
-                    type = int, default = 2048)
+parser.add_argument("--nb_heads", type=int, default=None)
 
-parser.add_argument('--nb_heads',
-                    type = int, default = 8)
+parser.add_argument("--nb_blocks", type=int, default=None)
 
-parser.add_argument('--nb_blocks',
-                    type = int, default = 12)
+parser.add_argument("--dropout", type=float, default=0.1)
 
-parser.add_argument('--dropout',
-                    type = float, default = 0.1)
+parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-parser.add_argument('--synthesis_sampling',
-                    action='store_true', default = True)
+parser.add_argument("--no_checkpoint", action="store_true", default=False)
 
-parser.add_argument('--checkpoint_name',
-                    type = str, default = 'checkpoint.pth')
+parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
 
-parser.add_argument('--picoclvr_many_colors',
-                    action='store_true', default = False)
+##############################
+# picoclvr options
+
+parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
+
+parser.add_argument("--picoclvr_height", type=int, default=12)
+
+parser.add_argument("--picoclvr_width", type=int, default=16)
 
 ######################################################################
 
 args = parser.parse_args()
 
-log_file = open(args.log_filename, 'w')
+log_file = open(args.log_filename, "w")
 
 if args.seed >= 0:
     torch.manual_seed(args.seed)
 
 ######################################################################
 
+
 def log_string(s):
-    t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime())
+    t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
 
     if log_file is not None:
-        log_file.write(t + s + '\n')
+        log_file.write(t + s + "\n")
         log_file.flush()
 
     print(t + s)
     sys.stdout.flush()
 
+
 for n in vars(args):
-    log_string(f'args.{n} {getattr(args, n)}')
+    log_string(f"args.{n} {getattr(args, n)}")
+
+######################################################################
+
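+# Per-dataset default hyper-parameters; any argument left to None on the
+# command line is filled in from this table below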
+default_args = {
+    "mnist": {
+        "nb_epochs": 10,
+        "dim_model": 64,
+        "dim_keys": 64,
+        "dim_hidden": 128,
+        "nb_heads": 4,
+        "nb_blocks": 6,
+    },
+    "mnist-debug": {
+        "nb_epochs": 2,
+        "data_size": 10000,
+        "dim_model": 8,
+        "dim_keys": 8,
+        "dim_hidden": 8,
+        "nb_heads": 2,
+        "nb_blocks": 4,
+    },
+    "wiki103": {
+        "nb_epochs": 25,
+        "dim_model": 512,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_blocks": 12,
+    },
+    "picoclvr": {
+        "nb_epochs": 25,
+        "dim_model": 512,
+        "dim_keys": 64,
+        "dim_hidden": 2048,
+        "nb_heads": 8,
+        "nb_blocks": 12,
+    },
+}
+
+if args.data in default_args:
+    for k, v in default_args[args.data].items():
+        if getattr(args, k) is None:
+            setattr(args, k, v)
+
+######################################################################
+
+
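+# Generate nb_samples sequences of nb_tokens_to_generate tokens each, one
+# token at a time, optionally conditioned on a common primer. Sampling is
+# greedy with --deterministic_synthesis, multinomial otherwise. Since
+# split() returns views, writing into input updates results in place.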
+def autoregression(
+    model,
+    batch_size,
+    nb_samples,
+    nb_tokens_to_generate,
+    primer=None,
+    device=torch.device("cpu"),
+):
+    results = torch.zeros(
+        nb_samples, nb_tokens_to_generate, dtype=torch.int64, device=device
+    )
+
+    if primer is None:
+        first = 0
+    else:
+        first = primer.size(1)
+        results = torch.cat((primer, results), 1)
+
+    for input in results.split(batch_size):
+        for s in range(first, input.size(1)):
+            output = model(input)
+            logits = output[:, s]
+            if args.deterministic_synthesis:
+                t_next = logits.argmax(1)
+            else:
+                dist = torch.distributions.categorical.Categorical(logits=logits)
+                t_next = dist.sample()
+            input[:, s] = t_next
+
+    return results
+
 
 ######################################################################
 
+
 class Task:
-    def batches(self, split = 'train'):
+    def batches(self, split="train"):
         pass
 
     def vocabulary_size(self):
         pass
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
+    def produce_results(self, n_epoch, model):
         pass
 
+
 ######################################################################
 
 import picoclvr
 
-class TaskPicoCLVR(Task):
 
-    def __init__(self, batch_size,
-                 height = 6, width = 8, many_colors = False,
-                 device = torch.device('cpu')):
+class TaskPicoCLVR(Task):
 
+    # Make a tensor from a list of strings
+    def tensorize(self, descr):
+        token_descr = [s.strip().split(" ") for s in descr]
+        l = max([len(s) for s in token_descr])
+        padded_token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
+        id_descr = [[self.token2id[u] for u in s] for s in padded_token_descr]
+        return torch.tensor(id_descr, device=self.device)
+
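+    # Remove the leading and trailing columns that contain only the <nul>
+    # padding token across the whole batch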
+    def trim(self, x, token="<nul>"):
+        n = self.token2id[token]
+        i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+        a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+        return x[:, a:b]
+
+    def __init__(
+        self, batch_size, height, width, nb_colors=5, device=torch.device("cpu")
+    ):
+        def generate_descr(nb):
+            return picoclvr.generate(
+                nb, height=self.height, width=self.width, nb_colors=nb_colors
+            )
+
+        self.height = height
+        self.width = width
         self.batch_size = batch_size
         self.device = device
-        nb = args.data_size if args.data_size > 0 else 250000
-
-        descr = picoclvr.generate(
-            nb,
-            height = height, width = width,
-            many_colors = many_colors
-        )
+        nb = args.data_size if args.data_size is not None else 250000
+
+        log_string(f"generating {nb} samples (can take some time)")
+        self.train_descr = generate_descr((nb * 4) // 5)
+        self.test_descr = generate_descr((nb * 1) // 5)
+
+        # Build the tokenizer
+        tokens = {"<nul>"}
+        for d in [self.train_descr, self.test_descr]:
+            for s in d:
+                for t in s.strip().split(" "):
+                    tokens.add(t)
+        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+
+        # Tokenize the train and test sets
+        self.train_input = self.tensorize(self.train_descr)
+        self.test_input = self.tensorize(self.test_descr)
+
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
+        input = self.train_input if split == "train" else self.test_input
+        for batch in tqdm.tqdm(input.split(self.batch_size), desc=f"epoch-{split}"):
+            yield self.trim(batch)
 
-        self.test_descr = descr[:nb // 5]
-        self.train_descr = descr[nb // 5:]
-
-        descr = [ s.strip().split(' ') for s in descr ]
-        l = max([ len(s) for s in descr ])
-        descr = [ s + [ '<unk>' ] * (l - len(s)) for s in descr ]
+    def vocabulary_size(self):
+        return len(self.token2id)
 
-        tokens = set()
-        for s in descr:
-            for t in s: tokens.add(t)
-        self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ])
-        self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ])
+    def test_model(
+        self, n_epoch, model, primers_descr, nb_per_primer=1, generate_images=False
+    ):
+        nb_tokens_to_generate = self.height * self.width + 3
+        result_descr = []
 
-        t = [ [ self.token2id[u] for u in s ] for s in descr ]
-        data_input = torch.tensor(t, device = self.device)
+        for primer_descr in primers_descr:
 
-        self.test_input = data_input[:nb // 5]
-        self.train_input = data_input[nb // 5:]
+            results = autoregression(
+                model,
+                self.batch_size,
+                nb_samples=nb_per_primer,
+                nb_tokens_to_generate=nb_tokens_to_generate,
+                primer=self.tensorize([primer_descr]).expand(nb_per_primer, -1),
+                device=self.device,
+            )
 
-    def batches(self, split = 'train'):
-        assert split in { 'train', 'test' }
-        if split == 'train':
-            for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
-        else:
-            for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = f'epoch-{split}'):
-                yield batch
+            l = [" ".join([self.id2token[t.item()] for t in r]) for r in results]
+            result_descr += l
 
-    def vocabulary_size(self):
-        return len(self.token2id)
+        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
 
-    def generate(self, primer, model, nb_tokens):
-        t_primer = primer.strip().split(' ')
-        t_generated = [ ]
+        nb_requested_properties, _, nb_missing_properties = zip(*np)
 
-        for j in range(nb_tokens):
-            t = [ [ self.token2id[u] for u in t_primer + t_generated ] ]
-            input = torch.tensor(t, device = self.device)
-            output = model(input)
-            logits = output[0, -1]
-            if args.synthesis_sampling:
-                dist = torch.distributions.categorical.Categorical(logits = logits)
-                t = dist.sample()
-            else:
-                t = logits.argmax()
-            t_generated.append(self.id2token[t.item()])
+        log_string(
+            f"nb_requested_properties {sum(nb_requested_properties) / len(result_descr):.02f} nb_missing_properties {sum(nb_missing_properties) / len(result_descr):.02f}"
+        )
 
-        return ' '.join(t_primer + t_generated)
+        np = torch.tensor(np)
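+        # count[i, j] is the number of generated descriptions with i
+        # requested properties among which j are missing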
+        count = torch.empty(np[:, 0].max() + 1, np[:, 2].max() + 1, dtype=torch.int64)
+        for i in range(count.size(0)):
+            for j in range(count.size(1)):
+                count[i, j] = ((np[:, 0] == i).long() * (np[:, 2] == j).long()).sum()
+
+        if generate_images:
+            img = [
+                picoclvr.descr2img(d, height=self.height, width=self.width)
+                for d in result_descr
+            ]
+
+            img = torch.cat(img, 0)
+            image_name = f"result_picoclvr_{n_epoch:04d}.png"
+            torchvision.utils.save_image(
+                img / 255.0, image_name, nrow=nb_per_primer, pad_value=0.8
+            )
+            log_string(f"wrote {image_name}")
+
+        return count
+
+    def produce_results(self, n_epoch, model):
+        primers_descr = [
+            "red above green <sep> green top <sep> blue right of red <img>",
+            "there is red <sep> there is yellow <sep> there is blue <img>",
+            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>",
+            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>",
+        ]
+
+        self.test_model(
+            n_epoch, model, primers_descr, nb_per_primer=8, generate_images=True
+        )
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
-        descr = [ ]
-        nb_per_primer = 8
+        # FAR TOO SLOW!!!
 
-        for primer in [
-                'red above green <sep> green top <sep> blue right of red <img>',
-                'there is red <sep> there is yellow <sep> there is blue <img>',
-                'red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left <img>',
-                'green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top <img>',
-        ]:
+        # test_primers_descr=[ s.split('<img>')[0] for s in self.test_descr ]
 
-            for k in range(nb_per_primer):
-                descr.append(self.generate(primer, model, nb_tokens))
+        # count=self.test_model(
+        # n_epoch, model,
+        # test_primers_descr,
+        # nb_per_primer=1, generate_images=False
+        # )
 
-        img = [ picoclvr.descr2img(d) for d in descr ]
-        img = torch.cat(img, 0)
-        file_name = f'result_picoclvr_{n_epoch:04d}.png'
-        torchvision.utils.save_image(img / 255.,
-                                     file_name, nrow = nb_per_primer, pad_value = 0.8)
-        log_string(f'wrote {file_name}')
+        # with open(f'perf_{n_epoch:04d}.txt', 'w') as f:
+        # for i in range(count.size(0)):
+        # for j in range(count.size(1)):
+        # f.write(f'{count[i,j]}')
+        # f.write(" " if j<count.size(1)-1 else "\n")
 
-        log_string(f'nb_misssing {picoclvr.nb_missing_properties(descr)}')
 
 ######################################################################
 
-class TaskWiki103(Task):
 
-    def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100,
-                 device = torch.device('cpu')):
+class TaskWiki103(Task):
+    def __init__(
+        self,
+        batch_size,
+        len_min=10,
+        len_max=200,
+        min_freq=100,
+        device=torch.device("cpu"),
+    ):
 
         self.batch_size = batch_size
         self.len_min = len_min
@@ -216,216 +341,239 @@ class TaskWiki103(Task):
         self.min_freq = min_freq
         self.device = device
 
-        self.tokenizer = torchtext.data.get_tokenizer('basic_english')
-        train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/')
+        self.tokenizer = torchtext.data.get_tokenizer("basic_english")
+        train_iter = torchtext.datasets.WikiText103(split="train", root="./data/nlp/")
 
         # Mostly for debug
-        if args.data_size > 0:
+        if args.data_size is not None:
             train_iter = itertools.islice(train_iter, args.data_size)
 
         def yield_tokens():
-            for l in tqdm.tqdm(train_iter, desc = 'vocab'):
+            for l in tqdm.tqdm(train_iter, desc="vocab"):
                 yield self.tokenizer(l)
 
         self.vocab = torchtext.vocab.build_vocab_from_iterator(
-            yield_tokens(),
-            specials = [ '<unk>', '<non>' ],
-            min_freq = self.min_freq
+            yield_tokens(), specials=["<unk>", "<nul>"], min_freq=self.min_freq
         )
 
-        self.vocab.set_default_index(self.vocab[ '<unk>' ])
+        self.vocab.set_default_index(self.vocab["<unk>"])
 
+    # makes a tensor from a list of lists of tokens
     def tensorize(self, s):
         a = max(len(x) for x in s)
-        return torch.tensor([ self.vocab(x + [ '<non>' ] * (a - len(x))) for x in s ])
+        return torch.tensor([self.vocab(x + ["<nul>"] * (a - len(x))) for x in s])
 
     def yield_batches(self, ds):
-        s = [ ]
+        s = []
         for l in ds:
             q = self.tokenizer(l)
             if len(q) >= self.len_min and len(q) <= self.len_max:
-                s += [ q ]
+                s += [q]
                 if len(s) == self.batch_size:
                     yield self.tensorize(s)
-                    s = [ ]
+                    s = []
 
         if len(s) > 0:
             yield self.tensorize(s)
 
-    def batches(self, split = 'train'):
-        data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/')
+    def batches(self, split="train"):
+        data_iter = torchtext.datasets.WikiText103(split=split, root="./data/nlp/")
 
         # Mostly for debug
-        if args.data_size > 0:
+        if args.data_size is not None:
             data_iter = itertools.islice(data_iter, args.data_size)
 
-        return self.yield_batches(tqdm.tqdm(data_iter, desc = f'epoch-{split}'))
+        return self.yield_batches(tqdm.tqdm(data_iter, desc=f"epoch-{split}"))
 
     def vocabulary_size(self):
         return len(self.vocab)
 
-    def produce_results(self, n_epoch, model, nb_tokens = 50):
-        file_name = f'result_wiki103_{n_epoch:04d}.txt'
-
-        with open(file_name, 'w') as outfile:
-             for primer in [
-                     'the cat is hunting a',
-                     'paris is the capital',
-                     'cars are convenient',
-                     'the difference between men and women is',
-                     'the object was blue all over and green all over it was',
-                     'cherries are red and lemons are',
-                     'cherries are sweet and lemons are',
-                     'two plus three equals',
-                     'deep learning is',
-             ]:
-                 t_primer = self.tokenizer(primer)
-                 t_generated = [ ]
-
-                 for j in range(nb_tokens):
-
-                     input = self.tensorize([ t_primer + t_generated ]).to(self.device)
-                     output = model(input)
-                     logits = output[0, -1]
-                     if args.synthesis_sampling:
-                         dist = torch.distributions.categorical.Categorical(logits = logits)
-                         t = dist.sample()
-                     else:
-                         t = logits.argmax()
-                     t_generated.append(self.vocab.lookup_token(t))
-                     if t_generated[-1] == '<non>': break
-
-                 s = ' '.join(t_generated)
-
-                 outfile.write(f'<{primer}> {s}\n')
-
-        log_string(f'wrote {file_name}')
+    def produce_results(self, n_epoch, model):
+        nb_tokens = 50
+        file_name = f"result_wiki103_{n_epoch:04d}.txt"
+
+        with open(file_name, "w") as outfile:
+            for primer in [
+                "the cat is hunting a",
+                "paris is the capital",
+                "cars are convenient",
+                "the difference between men and women is",
+                "the object was blue all over and green all over it was",
+                "cherries are red and lemons are",
+                "cherries are sweet and lemons are",
+                "two plus three equals",
+                "deep learning is",
+            ]:
+                t_primer = self.tokenizer(primer)
+                t_generated = []
+
+                for j in range(nb_tokens):
+
+                    input = self.tensorize([t_primer + t_generated]).to(self.device)
+                    input = F.pad(
+                        input, (0, 1)
+                    )  # Add the next token, the one to predict
+                    output = model(input)
+                    logits = output[0, -1]
+                    if args.deterministic_synthesis:
+                        t_next = logits.argmax()
+                    else:
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        t_next = dist.sample()
+                    t_generated.append(self.vocab.lookup_token(t_next))
+                    if t_generated[-1] == "<nul>":
+                        break
+
+                s = " ".join(t_generated)
+
+                outfile.write(f"<{primer}> {s}\n")
+
+        log_string(f"wrote {file_name}")
+
 
 ######################################################################
 
-class TaskMNIST(Task):
 
-    def __init__(self, batch_size, device = torch.device('cpu')):
+class TaskMNIST(Task):
+    def __init__(self, batch_size, device=torch.device("cpu")):
         self.device = device
         self.batch_size = batch_size
 
-    def batches(self, split = 'train'):
-        assert split in { 'train', 'test' }
+    def batches(self, split="train"):
+        assert split in {"train", "test"}
         data_set = torchvision.datasets.MNIST(
-            root = './data', train = (split == 'train'),
-            download = True
+            root="./data", train=(split == "train"), download=True
         )
         data_input = data_set.data.view(-1, 28 * 28).long()
-        if args.data_size >= 0:
-            data_input = data_input[:args.data_size]
-        for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = f'epoch-{split}'):
+        if args.data_size is not None:
+            data_input = data_input[: args.data_size]
+        for batch in tqdm.tqdm(
+            data_input.split(self.batch_size), desc=f"epoch-{split}"
+        ):
             yield batch
 
     def vocabulary_size(self):
         return 256
 
-    def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device)
-        for input in results.split(self.batch_size):
-            for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'):
-                output = model(input)
-                logits = output[:, s]
-                if args.synthesis_sampling:
-                    dist = torch.distributions.categorical.Categorical(logits = logits)
-                    t = dist.sample()
-                else:
-                    t = logits.argmax(1)
-                input[:, s + 1] = t
-
-        image_name = f'result_mnist_{n_epoch:04d}.png'
-        torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
-                                     image_name, nrow = 16, pad_value = 0.8)
-        log_string(f'wrote {image_name}')
-
-######################################################################
+    def produce_results(self, n_epoch, model):
+        nb_samples = 64
+        results = autoregression(
+            model, self.batch_size, nb_samples, 28 * 28, device=self.device
+        )
+        image_name = f"result_mnist_{n_epoch:04d}.png"
+        torchvision.utils.save_image(
+            1 - results.reshape(-1, 1, 28, 28) / 255.0,
+            image_name,
+            nrow=16,
+            pad_value=0.8,
+        )
+        log_string(f"wrote {image_name}")
 
-def check_causality(model):
-    #m = model[1:]
-    input = torch.rand(1, 5, dim_model).requires_grad_()
-    output = m(input)
-    a = torch.zeros(output.size(1), input.size(1))
-    for k in range(output.size(1)):
-        for d in range(output.size(2)):
-            g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True)
-            a[k] += g.squeeze(0).pow(2).sum(1)
-    print(a)
 
 ######################################################################
 
-log_string(f'device {device}')
-
-if args.data == 'wiki103':
-    task = TaskWiki103(batch_size = args.batch_size, device = device)
-elif args.data == 'mnist':
-    task = TaskMNIST(batch_size = args.batch_size, device = device)
-elif args.data == 'picoclvr':
-    task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device)
+log_string(f"device {device}")
+
+if args.data == "wiki103":
+    task = TaskWiki103(batch_size=args.batch_size, device=device)
+elif args.data in {"mnist", "mnist-debug"}:
+    task = TaskMNIST(batch_size=args.batch_size, device=device)
+elif args.data == "picoclvr":
+    task = TaskPicoCLVR(
+        batch_size=args.batch_size,
+        height=args.picoclvr_height,
+        width=args.picoclvr_width,
+        nb_colors=args.picoclvr_nb_colors,
+        device=device,
+    )
 else:
-    raise ValueError(f'Unknown dataset {args.data}.')
+    raise ValueError(f"Unknown dataset {args.data}.")
 
 vocabulary_size = task.vocabulary_size()
 
-log_string(f'vocabulary_size {vocabulary_size}')
+log_string(f"vocabulary_size {vocabulary_size}")
 
 ##############################
 
 model = mygpt.MyGPT(
-    vocabulary_size = vocabulary_size,
-    dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden,
-    nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
+    vocabulary_size=vocabulary_size,
+    dim_model=args.dim_model,
+    dim_keys=args.dim_keys,
+    dim_hidden=args.dim_hidden,
+    nb_heads=args.nb_heads,
+    nb_blocks=args.nb_blocks,
+    dropout=args.dropout,
 )
 
 model.to(device)
 
 nb_parameters = sum(p.numel() for p in model.parameters())
-log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
-
-######################################################################
-
-if args.optim == 'sgd':
-    optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adam':
-    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate)
-elif args.optim == 'adamw':
-    optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate)
-else:
-    raise ValueError(f'Unknown optimizer {args.optim}.')
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
 nb_epochs_finished = 0
 
-try:
-    checkpoint = torch.load(args.checkpoint_name, map_location = device)
-    nb_epochs_finished = checkpoint['nb_epochs_finished']
-    model.load_state_dict(checkpoint['model_state'])
-    optimizer.load_state_dict(checkpoint['optimizer_state'])
-    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
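+# Restore the model weights and the RNG states from the checkpoint, unless
+# --no_checkpoint is given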
+if args.no_checkpoint:
+    log_string("not trying to load checkpoint.")
 
-except FileNotFoundError:
-    print('Starting from scratch.')
-
-except:
-    print('Error when loading the checkpoint.')
-    exit(1)
+else:
+    try:
+        checkpoint = torch.load(args.checkpoint_name)
+        nb_epochs_finished = checkpoint["nb_epochs_finished"]
+        model.load_state_dict(checkpoint["model_state"])
+        torch.set_rng_state(checkpoint["rng_state"])
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
+        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")
+
+    except FileNotFoundError:
+        log_string("starting from scratch.")
+
+    except:
+        log_string("error when loading the checkpoint.")
+        exit(1)
 
 ######################################################################
 
-for k in range(nb_epochs_finished, args.nb_epochs):
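+# Compute the entropy of the training-set token distribution, to report the
+# perplexity of the data itself next to the model's prediction perplexity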
+token_count = 0
+for input in task.batches(split="train"):
+    token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+for n_epoch in range(nb_epochs_finished, args.nb_epochs):
+
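+    # Unless --learning_rate_end is negative, the learning rate decays
+    # geometrically from --learning_rate to --learning_rate_end over the epochs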
+    if args.learning_rate_end < 0:
+        lr = args.learning_rate
+    else:
+        u = n_epoch / (args.nb_epochs - 1)
+        lr = math.exp(
+            (1 - u) * math.log(args.learning_rate)
+            + u * math.log(args.learning_rate_end)
+        )
+        log_string(f"learning_rate {lr}")
+
+    if args.optim == "sgd":
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    elif args.optim == "adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    elif args.optim == "adamw":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+    else:
+        raise ValueError(f"Unknown optimizer {args.optim}.")
 
     model.train()
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input in task.batches(split = 'train'):
+    for input in task.batches(split="train"):
         input = input.to(device)
         output = model(input)
-        loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+        loss = F.cross_entropy(output.transpose(1, 2), input)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
 
@@ -439,26 +587,31 @@ for k in range(nb_epochs_finished, args.nb_epochs):
 
         nb_test_samples, acc_test_loss = 0, 0.0
 
-        for input in task.batches(split = 'test'):
+        for input in task.batches(split="test"):
             input = input.to(device)
             output = model(input)
-            loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:])
+            loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
 
-        train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples))
-        test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples))
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
 
-        log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}')
+        log_string(
+            f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
+        )
 
-        task.produce_results(k, model)
+        task.produce_results(n_epoch, model)
 
     checkpoint = {
-        'nb_epochs_finished': k + 1,
-        'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict()
+        "nb_epochs_finished": n_epoch + 1,
+        "model_state": model.state_dict(),
+        "rng_state": torch.get_rng_state(),
     }
 
+    if torch.cuda.is_available():
+        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
     torch.save(checkpoint, args.checkpoint_name)
 
 ######################################################################