78005274f3a9c9d73a52d96fbf207a3540039461
[beaver.git] / beaver.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, itertools, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # TF32 matrix multiplications are switched on only on the CUDA path.
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
# Command-line interface. The declaration order is kept stable because the
# options are later logged in the order in which they were added.
parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")

# Logging and output location.
parser.add_argument("--log_filename", type=str, default="train.log")
parser.add_argument("--result_dir", type=str, default="results_default")

# Random seed (a negative value disables seeding).
parser.add_argument("--seed", type=int, default=0)

# Training budget and optimisation.
parser.add_argument("--nb_epochs", type=int, default=25)
parser.add_argument("--nb_train_samples", type=int, default=200000)
parser.add_argument("--nb_test_samples", type=int, default=50000)
parser.add_argument("--batch_size", type=int, default=25)

# Restricting the choices makes a typo fail at start-up instead of raising
# a ValueError when the optimizer is first built.
parser.add_argument(
    "--optim", type=str, default="adam", choices=["sgd", "adam", "adamw"]
)
parser.add_argument("--learning_rate", type=float, default=1e-3)

# Either "auto", "cos", or a comma-separated list of "epoch: rate" steps.
parser.add_argument(
    "--learning_rate_schedule", type=str, default="10: 2e-4,20: 4e-5,30: 8e-6"
)

# Model dimensions.
parser.add_argument("--dim_model", type=int, default=512)
parser.add_argument("--dim_keys", type=int, default=64)
parser.add_argument("--dim_hidden", type=int, default=2048)
parser.add_argument("--nb_heads", type=int, default=8)
parser.add_argument("--nb_blocks", type=int, default=12)
parser.add_argument("--dropout", type=float, default=0.1)

# Behaviour switches ("store_true" already defaults to False).
parser.add_argument("--deterministic_synthesis", action="store_true")
parser.add_argument("--random_regression_order", action="store_true")
parser.add_argument("--noncausal_prompt", action="store_true")
parser.add_argument("--no_checkpoint", action="store_true")
parser.add_argument("--overwrite_results", action="store_true")
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# maze options

parser.add_argument("--maze_height", type=int, default=13)
parser.add_argument("--maze_width", type=int, default=21)
parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# one-shot prediction

parser.add_argument("--oneshot", action="store_true")

# The one-shot stage runs after the GPT training, so an invalid value here
# would otherwise only be detected once the whole training has completed.
parser.add_argument(
    "--oneshot_input", type=str, default="head", choices=["head", "deep"]
)
parser.add_argument(
    "--oneshot_output", type=str, default="trace", choices=["policy", "trace"]
)
94
95 ######################################################################
96
args = parser.parse_args()

# Create the result directory, including any missing parent. An already
# existing directory is reused only when --overwrite_results was given.
try:
    os.makedirs(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        # sys.exit is used rather than the site-provided exit() helper,
        # which is meant for interactive sessions and may be absent.
        sys.exit(1)

# Opened in append mode so that a resumed run extends the previous log.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

# A negative seed leaves the random generators unseeded.
if args.seed >= 0:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
115
116 ######################################################################
117
118
def log_string(s):
    """Emit ``s`` with a local-time stamp to the log file and to stdout.

    The line is written to the module-level ``log_file`` when it is not
    None, and always printed. Both streams are flushed immediately.
    """
    stamped = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(stamped + "\n")
        log_file.flush()

    print(stamped)
    sys.stdout.flush()
128
129
# Record every command-line option in the log, one option per line.
for n in vars(args):
    log_string("args.{} {}".format(n, getattr(args, n)))
132
133 ######################################################################
134
135
def reorder(x, order, reverse=False):  # x is NxTxD1x...xDk, order is NxT'
    """Permute ``x`` along its second dimension, independently per sample.

    With ``reverse=False`` the result satisfies ``out[n, t] = x[n, order[n, t]]``.
    With ``reverse=True`` it satisfies ``out[n, order[n, t]] = x[n, t]``, which
    undoes the forward call when every row of ``order`` is a permutation.
    Trailing dimensions of ``x`` are carried along unchanged.
    """
    # Collapse all trailing dimensions so that a single index per (n, t)
    # can be broadcast over them; a 2d input gets a trailing dimension of 1.
    flat = x.reshape(x.size(0), x.size(1), -1)
    index = order.unsqueeze(-1).expand(-1, -1, flat.size(-1))

    if reverse:
        permuted = torch.empty_like(flat).scatter_(1, index, flat)
    else:
        permuted = flat.gather(1, index)

    return permuted.reshape(permuted.size()[:2] + x.size()[2:])
145
146
def shuffle(x, prompt_len):
    """Return ``x`` reordered along dimension 1, together with the order used.

    Without --random_regression_order the order is the identity. With it,
    the first ``prompt_len`` positions keep their place and the remaining
    ones are permuted at random, independently for every row.
    """
    if not args.random_regression_order:
        identity = torch.arange(x.size(1), device=x.device)
        order = identity.unsqueeze(0).expand(x.size(0), -1)
        return reorder(x, order), order

    # Random sort keys in [0, 1) for the free positions; the prompt gets
    # increasing negative keys so that it sorts first and in its own order.
    keys = torch.rand(x.size(), device=x.device)
    keys[:, :prompt_len] = torch.arange(-prompt_len, 0, device=x.device)
    order = keys.sort(1).indices
    return reorder(x, order), order
157
158
def eval_mygpt(model, input, mode="standard", prompt_len=0):
    """Run ``model`` on ``input`` and return its output in the original order.

    The input is first reordered by ``shuffle`` (keeping the first
    ``prompt_len`` positions in place), the model is applied with that
    order, and the output is mapped back to the original positions.
    """
    shuffled, order = shuffle(input, prompt_len)
    sequence = mygpt.BracketedSequence(shuffled)
    output = model(sequence, mode=mode, order=order).x
    return reorder(output, order, reverse=True)
163
164
165 ######################################################################
166
167 # ar_mask is a 0/1 integer tensor of the same shape as input, with 1s
168 # on the tokens that should be generated
169
170
def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
    """Fill the masked positions of ``input`` in place, one position at a time.

    ``ar_mask`` has the shape of ``input`` and holds 1 where a token has to
    be generated and 0 where the existing token is kept; it is used
    arithmetically, so it must be numeric rather than Boolean. The tensors
    are processed by chunks of ``batch_size`` rows. ``torch.split`` returns
    views, so the writes land in the caller's ``input``.

    Tokens are taken as the argmax of the logits when
    --deterministic_synthesis is set, and sampled from them otherwise.

    ``order`` is split like the other tensors and forwarded to the model.
    When it is None, None is forwarded for every chunk (this assumes the
    model accepts ``order=None`` -- TODO confirm against mygpt).
    """
    if order is None:
        # The declared default used to crash on ``order.split``.
        order_chunks = itertools.repeat(None)
    else:
        order_chunks = order.split(batch_size)

    for chunk, chunk_mask, chunk_order in zip(
        input.split(batch_size), ar_mask.split(batch_size), order_chunks
    ):
        # Positions at which at least one row of the chunk has to be generated.
        i = (chunk_mask.sum(0) > 0).nonzero()
        if i.numel() == 0:
            # Nothing to generate here; min() of an empty tensor would raise.
            continue

        first, last = i.min(), i.max()

        if first > 0:
            # Needed to initialize the model's cache
            model(mygpt.BracketedSequence(chunk, 0, first), order=chunk_order)

        for s in range(first, last + 1):
            output = model(mygpt.BracketedSequence(chunk, s, 1), order=chunk_order).x
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # Overwrite only the rows whose mask is set at this position.
            chunk[:, s] = chunk_mask[:, s] * t_next + (1 - chunk_mask[:, s]) * chunk[:, s]
188
189
190 ######################################################################
191
192
def compute_perplexity(model, task, prompt_len, split="train"):
    """Return the perplexity of ``model`` on the given split of ``task``.

    The per-batch mean cross-entropy is averaged over the batches with
    weights equal to their number of rows, then exponentiated; the exponent
    is capped at 100. With --noncausal_prompt only the second half of each
    sequence contributes to the loss. The model is evaluated in eval mode
    without gradients, and its training flag is restored afterwards.
    """
    was_training = model.training
    model.eval()

    total_loss, total_samples = 0.0, 0

    with torch.no_grad():
        for batch in task.batches(split=split):
            batch = batch.to(device)
            output = eval_mygpt(model, batch, prompt_len=prompt_len)
            if args.noncausal_prompt:
                half = batch.size(1) // 2
                loss = F.cross_entropy(
                    output[:, half:].transpose(1, 2), batch[:, half:]
                )
            else:
                loss = F.cross_entropy(output.transpose(1, 2), batch)
            total_loss += loss.item() * batch.size(0)
            total_samples += batch.size(0)

    model.train(was_training)

    return math.exp(min(100, total_loss / total_samples))
214
215
216 ######################################################################
217
218
def oneshot_policy_loss(mazes, output, policies, height, width):
    """Cross-entropy between predicted and target policies on empty cells.

    Cells of ``mazes`` different from ``maze.v_empty`` are zeroed in both
    the predictions and the targets, and the summed loss is divided by the
    number of empty cells. ``height`` and ``width`` are not used; they keep
    the signature interchangeable with ``oneshot_trace_loss``.
    """
    # assumes output is N x (H*W) x 4 and policies is N x 4 x (H*W) -- TODO confirm
    empty = (mazes == maze.v_empty).unsqueeze(-1)
    log_probas = (output * empty).log_softmax(-1)
    targets = policies.permute(0, 2, 1) * empty
    return -(log_probas * targets).sum() / empty.sum()
224
225
def oneshot_trace_loss(mazes, output, policies, height, width):
    """Mean absolute error between predicted and target densities on empty cells.

    The targets are computed by ``maze.stationary_densities`` from the mazes
    and policies viewed as ``height`` x ``width`` maps. Cells different from
    ``maze.v_empty`` are zeroed on both sides, and the summed error is
    divided by the number of empty cells.
    """
    empty = mazes == maze.v_empty
    densities = maze.stationary_densities(
        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
    ).flatten(-2)
    error = (output.squeeze(-1) * empty - densities * empty).abs()
    return error.sum() / empty.sum()
234
235
def oneshot(gpt, learning_rate_scheduler, task):
    """Train a small MLP to predict per-cell quantities from ``gpt`` features.

    ``gpt`` is put in eval mode and used as a fixed feature extractor: only
    the parameters of the MLP are given to the optimizer. Depending on
    --oneshot_output the MLP predicts a 4-way policy or a scalar trace per
    cell. After every epoch the train and test losses are logged and an
    image of the predictions on the first 32 test mazes is written to the
    result directory. The training flag of ``gpt`` is restored at the end.

    Raises:
        ValueError: on an unknown --oneshot_input or --oneshot_output.
    """
    t = gpt.training
    gpt.eval()

    if args.oneshot_input == "head":
        dim_in = args.dim_model
    elif args.oneshot_input == "deep":
        dim_in = args.dim_model * args.nb_blocks * 2
    else:
        raise ValueError(f"{args.oneshot_input=}")

    if args.oneshot_output == "policy":
        dim_out = 4
        compute_loss = oneshot_policy_loss
    elif args.oneshot_output == "trace":
        dim_out = 1
        compute_loss = oneshot_trace_loss
    else:
        raise ValueError(f"{args.oneshot_output=}")

    model = nn.Sequential(
        nn.Linear(dim_in, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, dim_out),
    ).to(device)

    prompt_len = task.height * task.width

    def gpt_features(mazes):
        # The GPT parameters are not optimized here, so back-propagating
        # through it would only cost time and memory.
        with torch.no_grad():
            return eval_mygpt(
                gpt, mazes, mode=args.oneshot_input, prompt_len=prompt_len
            )

    learning_rate_scheduler.reset()

    for n_epoch in range(args.nb_epochs):
        learning_rate = learning_rate_scheduler.get_learning_rate()
        log_string(f"learning_rate {n_epoch} {learning_rate}")

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        acc_train_loss, nb_train_samples = 0, 0
        for mazes, policies in task.policy_batches(split="train"):
            output = model(gpt_features(mazes))

            loss = compute_loss(mazes, output, policies, task.height, task.width)
            acc_train_loss += loss.item() * mazes.size(0)
            nb_train_samples += mazes.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        learning_rate_scheduler.update(n_epoch + 1, acc_train_loss)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            acc_test_loss, nb_test_samples = 0, 0
            for mazes, policies in task.policy_batches(split="test"):
                output = model(gpt_features(mazes))
                loss = compute_loss(mazes, output, policies, task.height, task.width)
                acc_test_loss += loss.item() * mazes.size(0)
                nb_test_samples += mazes.size(0)

        log_string(
            f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
        )

        # -------------------
        # Picture of the predictions on the first 32 test mazes.
        with torch.no_grad():
            mazes = task.test_input[:32, :prompt_len]
            policies = task.test_policies[:32]
            output = model(gpt_features(mazes))
            if args.oneshot_output == "policy":
                targets = policies.permute(0, 2, 1)
                # 1.0 where the predicted action has a zero target value.
                scores = (
                    (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1) == 0
                ).float()
            elif args.oneshot_output == "trace":
                targets = maze.stationary_densities(
                    mazes.view(-1, task.height, task.width),
                    policies.view(-1, 4, task.height, task.width),
                ).flatten(-2)
                scores = output
            else:
                raise ValueError(f"{args.oneshot_output=}")

            scores = scores.reshape(-1, task.height, task.width)
            mazes = mazes.reshape(-1, task.height, task.width)
            targets = targets.reshape(-1, task.height, task.width)
            filename = (
                f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
            )
            maze.save_image(
                os.path.join(args.result_dir, filename),
                mazes=mazes,
                score_paths=scores,
                score_truth=targets,
            )
        log_string(f"wrote {filename}")

        # -------------------

    gpt.train(t)
341
342
343 ######################################################################
344
345
class LearningRateScheduler:
    """Interface of the learning-rate schedulers used by the training loops.

    The methods of this base class do nothing; subclasses override them.
    """

    def get_learning_rate(self):
        """Return the learning rate to use now (None in this base class)."""

    def update(self, nb_finished_epochs, loss):
        """Take note that an epoch ended with the given accumulated loss."""

    def reset(self):
        """Bring the scheduler back to its initial state."""

    def get_state(self):
        """Return a dict of the instance attributes, for checkpointing.

        A shallow copy is returned so that updating the scheduler afterwards
        does not alter a snapshot that has already been taken.
        """
        return dict(vars(self))

    def set_state(self, state):
        """Restore the attributes listed in ``state`` and print the state."""
        print(f"{state=}")
        for k, v in state.items():
            setattr(self, k, v)
363
364
class StepWiseScheduler(LearningRateScheduler):
    """Scheduler that reads the learning rate from a per-epoch table.

    ``schedule`` is indexed with the number of finished epochs, so it must
    provide an entry for every epoch at which a rate is requested.
    """

    def __init__(self, schedule):
        self.schedule = schedule
        self.nb_finished_epochs = 0

    def get_learning_rate(self):
        """Return the table entry for the current number of finished epochs."""
        epoch = self.nb_finished_epochs
        return self.schedule[epoch]

    def update(self, nb_finished_epochs, loss):
        """Record the number of finished epochs; ``loss`` is ignored."""
        self.nb_finished_epochs = nb_finished_epochs

    def reset(self):
        """Go back to the first entry of the table."""
        self.nb_finished_epochs = 0

    def get_state(self):
        """Return the epoch counter only; the table itself is not saved."""
        return dict(nb_finished_epochs=self.nb_finished_epochs)
381
382
class AutoScheduler(LearningRateScheduler):
    """Scheduler that rescales the learning rate from the loss evolution.

    After every epoch the rate is multiplied by ``degrowth`` when the loss
    did not decrease with respect to the previous epoch, and by ``growth``
    when it did.
    """

    def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2):
        self.learning_rate_init = learning_rate_init
        self.learning_rate = learning_rate_init
        self.growth = growth
        self.degrowth = degrowth
        # Loss of the previous epoch, None until one epoch has finished.
        self.pred_loss = None

    def get_learning_rate(self):
        """Return the current learning rate."""
        return self.learning_rate

    def update(self, nb_finished_epochs, loss):
        """Rescale the rate by comparing ``loss`` with the previous one.

        ``nb_finished_epochs`` is not used. The first call after a reset
        only records the loss.
        """
        if self.pred_loss is not None:
            if loss >= self.pred_loss:
                self.learning_rate *= self.degrowth
            else:
                self.learning_rate *= self.growth
        self.pred_loss = loss

    def reset(self):
        """Restore the initial rate and forget the previous loss.

        The previous loss is cleared as well: a reset scheduler may be
        reused to train another model, whose losses are not comparable.
        """
        self.learning_rate = self.learning_rate_init
        self.pred_loss = None

    def get_state(self):
        """Return the values needed to resume: rates and previous loss.

        The current rate is included; without it, a resumed run would
        silently restart from the initial rate.
        """
        return {
            "learning_rate_init": self.learning_rate_init,
            "learning_rate": self.learning_rate,
            "pred_loss": self.pred_loss,
        }
410
411
412 ######################################################################
413
414
class Task:
    """Interface of a task; every method here is a no-op returning None."""

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Meant to provide the batches of a split; overridden by subclasses."""
        return None

    def vocabulary_size(self):
        """Meant to give the number of token values; overridden by subclasses."""
        return None

    def produce_results(self, n_epoch, model):
        """Meant to evaluate ``model`` after an epoch; overridden by subclasses."""
        return None
424
425
426 ######################################################################
427
428 import maze
429
430
class TaskMaze(Task):
    """Maze path-finding task.

    A sample is a sequence made of a flattened maze followed by its
    flattened path map, both of size ``height * width``. The per-cell
    policies returned by ``maze.create_maze_data`` are kept on the side
    for the one-shot experiments.
    """

    def map2seq(self, *m):
        """Flatten every map of ``m`` per sample and concatenate the results."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Split sequences back into maps of size ``height`` x ``width``.

        Returns a generator yielding one N x height x width tensor per map
        packed in the sequences.
        """
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        """Generate the train and test sets and move them to ``device``."""
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, train_policies = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
        self.train_policies = train_policies.flatten(-2).to(device)

        test_mazes, test_paths, test_policies = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
        self.test_policies = test_policies.flatten(-2).to(device)

        # Stored as a plain int rather than a 0-dim tensor, since it is used
        # as a size (number of classes, vocabulary size) and gets logged.
        self.nb_codes = int(self.train_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield the sequences of a split by chunks of ``batch_size`` rows.

        Only the first ``nb_to_use`` samples are used when it is positive.
        ``desc`` labels the progress bar and defaults to "epoch-<split>".
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield (mazes, policies) pairs by chunks of ``batch_size`` rows.

        The mazes are the first ``height * width`` tokens of the sequences,
        and the policies are zeroed on the cells equal to ``maze.v_wall``.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        policies = self.train_policies if split == "train" else self.test_policies
        input = input[:, : self.height * self.width]
        policies = policies * (input != maze.v_wall)[:, None]

        if nb_to_use > 0:
            input = input[:nb_to_use]
            policies = policies[:nb_to_use]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            zip(input.split(self.batch_size), policies.split(self.batch_size)),
            dynamic_ncols=True,
            desc=desc,
        ):
            yield batch

    def vocabulary_size(self):
        """Return the number of distinct token values, as an int."""
        return self.nb_codes

    def _generate_paths(self, model, input):
        """Return a copy of ``input`` whose path part is generated by ``model``.

        The tokens after the first ``height * width`` ones are zeroed and
        regenerated auto-regressively; the maze part is left untouched.
        """
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, self.height * self.width :] = 1
        result *= 1 - ar_mask
        x, order = shuffle(result, self.height * self.width)
        masked_inplace_autoregression(
            model, self.batch_size, x, ar_mask, order=order
        )
        return reorder(x, order, reverse=True)

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Count the generated paths that ``maze.path_correctness`` accepts.

        Returns ``(nb_total, nb_correct)`` over the first ``nb_to_use``
        samples of the split (all of them when it is not positive).
        """
        nb_total, nb_correct = 0, 0
        # Iterate over this task's own data, not the module-level ``task``.
        for input in self.batches(split, nb_to_use):
            result = self._generate_paths(model, input)
            mazes, paths = self.seq2map(result)
            nb_correct += maze.path_correctness(mazes, paths).long().sum()
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log the train/test accuracies and save a picture of 32 test mazes.

        The accuracies are estimated on the first 1000 samples of each
        split. The model is evaluated in eval mode without gradients, and
        its training flag is restored afterwards.
        """
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            train_nb_total, train_nb_correct = self.compute_error(
                model, "train", nb_to_use=1000
            )
            log_string(
                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct = self.compute_error(
                model, "test", nb_to_use=1000
            )
            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            input = self.test_input[:32]
            result = self._generate_paths(model, input)

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)
            filename = f"result_{n_epoch:04d}.png"
            maze.save_image(
                os.path.join(args.result_dir, filename),
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
            )
            log_string(f"wrote {filename}")

            model.train(t)
572
573
574 ######################################################################
575
log_string(f"device {device}")


# The data sets are generated once, at construction.
task = TaskMaze(
    nb_train_samples=args.nb_train_samples,
    nb_test_samples=args.nb_test_samples,
    batch_size=args.batch_size,
    height=args.maze_height,
    width=args.maze_width,
    nb_walls=args.maze_nb_walls,
    device=device,
)


# Forced to a plain int: the task may report it as a 0-dim tensor, while
# it is used as a size by the model and written to the log.
vocabulary_size = int(task.vocabulary_size())

log_string(f"vocabulary_size {vocabulary_size}")
593
594 ##############################
595
596
def noncausal_prompt_amm_generator(d):
    """Return the d x d Boolean matrix that is True where row index < column index.

    The matrix is handed to ``mygpt.MyGPT`` as ``amm_generator`` when
    --noncausal_prompt is set (presumably True marks a key position that a
    query may not attend to -- TODO confirm against mygpt).
    """
    q = torch.arange(d)[:, None]
    k = torch.arange(d)[None, :]
    # Disabled variant, restricted to pairs with at least one index outside
    # the prompt of size s = args.maze_height * args.maze_width:
    #    return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
    return q < k
603
604
# The custom attention-mask generator is only used with --noncausal_prompt.
amm_generator = noncausal_prompt_amm_generator if args.noncausal_prompt else None

model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
    amm_generator=amm_generator,
)

model.to(device)

# Total number of scalar parameters, logged exactly and in millions.
nb_parameters = sum(p.numel() for p in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
626
627 ######################################################################
628
# Build the scheduler selected with --learning_rate_schedule.
if args.learning_rate_schedule == "auto":
    # Rate adapted from the evolution of the training loss.
    learning_rate_scheduler = AutoScheduler(args.learning_rate)

elif args.learning_rate_schedule == "cos":
    # Half cosine going from the initial rate towards zero over the epochs.
    schedule = {}
    for n_epoch in range(args.nb_epochs):
        u = n_epoch / args.nb_epochs * math.pi
        schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
    learning_rate_scheduler = StepWiseScheduler(schedule)
    log_string(f"learning_rate_schedule {schedule}")

else:
    # Explicit steps given as "epoch: rate,epoch: rate,...".
    u = {}
    for item in args.learning_rate_schedule.split(","):
        k, v = item.split(":")
        u[int(k)] = float(v)

    # Expand the steps into one rate per epoch; a rate stays in force
    # until the next step.
    schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(args.nb_epochs):
        learning_rate = u.get(n_epoch, learning_rate)
        schedule[n_epoch] = learning_rate
    learning_rate_scheduler = StepWiseScheduler(schedule)
    log_string(f"learning_rate_schedule {schedule}")
656
657 ######################################################################
658
# Number of epochs already completed by a previous run, 0 when starting anew.
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f"not trying to load checkpoint.")

else:
    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)

    # Only the read itself is guarded: a missing file means a fresh start,
    # while any error raised during the restoration must not be mistaken
    # for it.
    try:
        checkpoint = torch.load(checkpoint_name)
    except FileNotFoundError:
        log_string("starting from scratch.")
    else:
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        learning_rate_scheduler.set_state(checkpoint["learning_rate_scheduler_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        # The CUDA generator state is only stored by runs that had a GPU.
        if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    # except:
    # log_string("error when loading the checkpoint.")
    # exit(1)
683
684 ######################################################################
685
# Perplexity of the unigram distribution of the training tokens, logged
# next to the model perplexities.
token_count = sum(
    F.one_hot(batch, num_classes=task.vocabulary_size()).sum((0, 1))
    for batch in task.batches(split="train")
)
token_probas = token_count / token_count.sum()
# xlogy returns 0 for a zero probability, so unseen tokens add nothing.
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
692
693 ##############################
694
# Every requested epoch was already done by a previous run: nothing will
# be trained below, so only report the metrics of the loaded model.
if nb_epochs_finished >= args.nb_epochs:
    n_epoch = nb_epochs_finished
    train_perplexity, test_perplexity = (
        compute_perplexity(
            model, task, prompt_len=task.height * task.width, split=split
        )
        for split in ("train", "test")
    )

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)
709
710 ##############################
711
# The scheduler is reset only for a fresh run. When resuming, its state has
# just been restored from the checkpoint, and resetting it would make the
# first resumed epoch use the initial learning rate.
if nb_epochs_finished == 0:
    learning_rate_scheduler.reset()

for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    learning_rate = learning_rate_scheduler.get_learning_rate()
    log_string(f"learning_rate {n_epoch} {learning_rate}")

    # A new optimizer is created at every epoch with the current rate.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"{args.optim=}")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        output = eval_mygpt(model, input, prompt_len=task.height * task.width)
        if args.noncausal_prompt:
            # Only the second half of the sequence contributes to the loss.
            d = input.size(1) // 2
            loss = F.cross_entropy(output[:, d:].transpose(1, 2), input[:, d:])
        else:
            loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    learning_rate_scheduler.update(n_epoch + 1, acc_train_loss)

    # The exponent is capped at 100, as in compute_perplexity.
    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
    test_perplexity = compute_perplexity(
        model, task, prompt_len=task.height * task.width, split="test"
    )

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)

    # Everything needed to resume after this epoch.
    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "learning_rate_scheduler_state": learning_rate_scheduler.get_state(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    torch.save(checkpoint, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")

######################################################################

# Optional second stage: train a one-shot predictor on top of the GPT.
if args.oneshot:
    oneshot(model, learning_rate_scheduler, task)

######################################################################
777
778 ######################################################################