Update
[beaver.git] / beaver.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 # torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, itertools, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Prefer the GPU when one is present; on CUDA also allow TF32 for matmuls.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")

# Run setup and outputs

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument("--result_dir", type=str, default="results_default")

parser.add_argument(
    "--seed", type=int, default=0, help="RNG seed; a negative value disables seeding"
)

# Training

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--nb_train_samples", type=int, default=200000)

parser.add_argument("--nb_test_samples", type=int, default=50000)

parser.add_argument("--batch_size", type=int, default=25)

# Enumerated options are validated at parse time, so a typo is rejected
# before any data is generated or any epoch is trained.
parser.add_argument(
    "--optim", type=str, default="adam", choices=["sgd", "adam", "adamw"]
)

parser.add_argument("--learning_rate", type=float, default=1e-3)

parser.add_argument(
    "--learning_rate_schedule",
    type=str,
    default="10: 2e-4,20: 4e-5,30: 8e-6",
    help='"cos", or comma-separated "epoch: learning_rate" pairs',
)

# Model

parser.add_argument("--dim_model", type=int, default=512)

parser.add_argument("--dim_keys", type=int, default=64)

parser.add_argument("--dim_hidden", type=int, default=2048)

parser.add_argument("--nb_heads", type=int, default=8)

parser.add_argument("--nb_blocks", type=int, default=12)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument("--deterministic_synthesis", action="store_true", default=False)

parser.add_argument("--random_regression_order", action="store_true", default=False)

parser.add_argument("--noncausal_prompt", action="store_true", default=False)

# Checkpointing

parser.add_argument("--no_checkpoint", action="store_true", default=False)

parser.add_argument("--overwrite_results", action="store_true", default=False)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# one-shot prediction

parser.add_argument("--oneshot", action="store_true", default=False)

parser.add_argument(
    "--oneshot_input", type=str, default="head", choices=["head", "deep"]
)

parser.add_argument(
    "--oneshot_output", type=str, default="trace", choices=["policy", "trace"]
)
94
95 ######################################################################
96
args = parser.parse_args()

# Refuse to reuse an existing result directory unless explicitly allowed.
try:
    os.mkdir(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        # sys.exit rather than the `exit` builtin, which is only injected by
        # the `site` module and is not guaranteed to exist.
        sys.exit(1)

# Append mode: a resumed run keeps extending the same log file.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
107
# Seed the CPU generator and, when present, every CUDA generator. A negative
# --seed leaves everything unseeded. Full determinism would additionally need
# torch.use_deterministic_algorithms(True) and cudnn.deterministic = True,
# which are left disabled.
if args.seed >= 0:
    seed = args.seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
115
116 ######################################################################
117
118
def log_string(s):
    """Emit `s` prefixed with a local "%Y%m%d-%H:%M:%S " timestamp.

    The line is appended to the module-level `log_file` (skipped when it is
    None) and echoed to stdout; both are flushed immediately.
    """
    line = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        print(line, file=log_file)
        log_file.flush()

    print(line)
    sys.stdout.flush()
128
129
# Record every command-line option so each run's log is self-describing.
for name, value in vars(args).items():
    log_string(f"args.{name} {value}")
132
133 ######################################################################
134
135
def generation_order(x, fixed_len=0):
    """Return the per-row order in which the positions of `x` are generated.

    Assumes x is an (N, T) token tensor (TODO confirm for other ranks); the
    result is an (N, T) long tensor of position indices on x's device.

    Without --random_regression_order every row is simply 0..T-1. With it, the
    first `fixed_len` positions keep their place and the remaining ones
    follow in an independent random order per row.
    """
    if not args.random_regression_order:
        identity = torch.arange(x.size(1), device=x.device)
        return identity.unsqueeze(0).expand(x.size(0), -1)

    keys = torch.rand(x.size(), device=x.device)
    # Negative keys sort ahead of every U[0, 1) key, in increasing order, so
    # the fixed prefix stays first and in place.
    keys[:, :fixed_len] = torch.arange(-fixed_len, 0, device=x.device)
    return keys.sort(1).indices
146
147
def reorder(x, order, reverse=False):
    """Permute dimension 1 of `x` according to `order`.

    x is N x T x D1 x ... x Dk (k may be 0) and order is an N x T' tensor of
    indices into dim 1.

    Forward:       out[n, t] = x[n, order[n, t]]
    reverse=True:  out[n, order[n, t]] = x[n, t]

    The reverse form fills every entry only when each row of `order` is a
    permutation of 0..T-1; other entries are left uninitialized.
    """
    flat = x.reshape(*x.shape[:2], -1)
    index = order.unsqueeze(-1).expand(-1, -1, flat.size(-1))
    if reverse:
        out = flat.new_empty(flat.size()).scatter_(1, index, flat)
    else:
        out = torch.gather(flat, 1, index)
    return out.reshape(*out.shape[:2], *x.shape[2:])
157
158
def shuffle(x, fixed_len):
    """Reorder `x` along dim 1 by a fresh `generation_order`.

    Returns (reordered x, order); the original layout can be recovered with
    reorder(..., order, reverse=True).
    """
    order = generation_order(x, fixed_len)
    shuffled = reorder(x, order)
    return shuffled, order
162
163
def eval_mygpt(model, input, mode="standard", fixed_len=0):
    """Run `model` on `input` presented in generation order, then undo it.

    The tokens are permuted with `shuffle` (the same `order` is handed to the
    model), and the model's per-position output is mapped back to the
    original position layout before being returned.
    """
    shuffled, order = shuffle(input, fixed_len)
    result = model(mygpt.BracketedSequence(shuffled), mode=mode, order=order)
    return reorder(result.x, order, reverse=True)
168
169
170 ######################################################################
171
172 # ar_mask is an integer 0/1 tensor of the same shape as input, with 1s on
173 # the tokens that should be generated (it must not be bool: the code computes 1 - ar_mask)
174
175
def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
    """Generate, in place, the tokens of `input` flagged by `ar_mask`.

    `input` and `ar_mask` (an integer 0/1 tensor of the same shape) are
    processed in chunks of `batch_size` rows. For each chunk, the model is
    first called on the columns before the first masked column, to
    initialize its cache. It is then called one column at a time over the
    span of columns where any row is masked. At each such column the
    sampled token (argmax with --deterministic_synthesis) overwrites `input`
    where the mask is 1, and the existing token is kept where it is 0.

    `input` is modified in place through the views returned by `split`.
    Callers in this file run it under torch.no_grad().

    `order`, when given, is the per-row generation order forwarded to the
    model and is split alongside `input`. With the default None, the model
    is called with order=None.
    """
    orders = itertools.repeat(None) if order is None else order.split(batch_size)

    for chunk, chunk_mask, chunk_order in zip(
        input.split(batch_size), ar_mask.split(batch_size), orders
    ):
        columns = (chunk_mask.sum(0) > 0).nonzero()
        if columns.numel() == 0:
            # Nothing to generate in this chunk.
            continue
        first, last = columns.min().item(), columns.max().item()

        if first > 0:
            # Needed to initialize the model's cache
            model(mygpt.BracketedSequence(chunk, 0, first), order=chunk_order)

        for s in range(first, last + 1):
            output = model(mygpt.BracketedSequence(chunk, s, 1), order=chunk_order).x
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            chunk[:, s] = (
                chunk_mask[:, s] * t_next + (1 - chunk_mask[:, s]) * chunk[:, s]
            )
193
194
195 ######################################################################
196
197
def compute_perplexity(model, task, fixed_len, split="train"):
    """Return exp of the sample-weighted mean cross-entropy on a task split.

    The model is evaluated in eval mode under no_grad, and its previous
    training flag is restored afterwards. Each batch's mean token loss is
    weighted by its number of samples. The exponent is capped at 100 so that
    math.exp stays finite.
    """
    with torch.no_grad():
        was_training = model.training
        model.eval()

        total_loss, total_samples = 0.0, 0
        for batch in task.batches(split=split):
            batch = batch.to(device)
            logits = eval_mygpt(model, batch, fixed_len=fixed_len)
            batch_loss = F.cross_entropy(logits.transpose(1, 2), batch)
            total_loss += batch_loss.item() * batch.size(0)
            total_samples += batch.size(0)

        model.train(was_training)

    return math.exp(min(100, total_loss / total_samples))
215
216
217 ######################################################################
218
219
def oneshot_policy_loss(mazes, output, policies, height, width):
    """Cross-entropy between predicted and target policies on empty cells.

    As fed by `oneshot`, presumably mazes is N x (H*W) cell codes, output is
    N x (H*W) x 4 and policies is N x 4 x (H*W) (TODO confirm). Only cells
    equal to maze.v_empty contribute, and the sum is divided by their count.
    `height` and `width` are unused here; they only keep the signature
    shared with oneshot_trace_loss.
    """
    empty = (mazes == maze.v_empty).unsqueeze(-1)
    target = policies.permute(0, 2, 1) * empty
    log_probs = (output * empty).log_softmax(-1)
    return -(log_probs * target).sum() / empty.sum()
225
226
def oneshot_trace_loss(mazes, output, policies, height, width):
    """Mean absolute error between the prediction and stationary densities.

    The target is maze.stationary_densities of each (height x width) maze
    under its 4-channel policy, flattened back to one value per cell. Only
    cells equal to maze.v_empty count; the summed absolute error is divided
    by their number.
    """
    empty = mazes == maze.v_empty
    densities = maze.stationary_densities(
        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
    )
    target = densities.flatten(-2) * empty
    prediction = output.squeeze(-1) * empty
    return (prediction - target).abs().sum() / empty.sum()
235
236
def oneshot(gpt, task):
    """Train a small probe on frozen GPT features to predict maze policies
    or traces.

    `args.oneshot_input` selects the GPT features that are read: "head"
    gives width dim_model, "deep" gives width dim_model * nb_blocks * 2.
    `args.oneshot_output` selects what is predicted:
      - "policy": 4 values per cell, trained with oneshot_policy_loss;
      - "trace": 1 value per cell, trained with oneshot_trace_loss.

    Each epoch logs the mean train and test loss. It also writes an image of
    the probe's output on the first 32 test mazes into args.result_dir.

    The GPT is used only as a frozen feature extractor. It is put in eval
    mode and queried under torch.no_grad(), so nothing is back-propagated
    through it. Its previous training flag is restored on exit, even on error.

    Raises:
        ValueError: if args.oneshot_input or args.oneshot_output is unknown
            (checked before the GPT's mode is touched).
    """
    if args.oneshot_input == "head":
        dim_in = args.dim_model
    elif args.oneshot_input == "deep":
        dim_in = args.dim_model * args.nb_blocks * 2
    else:
        raise ValueError(f"{args.oneshot_input=}")

    if args.oneshot_output == "policy":
        dim_out = 4
        compute_loss = oneshot_policy_loss
    elif args.oneshot_output == "trace":
        dim_out = 1
        compute_loss = oneshot_trace_loss
    else:
        raise ValueError(f"{args.oneshot_output=}")

    fixed_len = task.height * task.width

    def gpt_features(mazes):
        # Frozen GPT: no autograd graph through it, so loss.backward() only
        # reaches the probe's parameters.
        with torch.no_grad():
            return eval_mygpt(
                gpt, mazes, mode=args.oneshot_input, fixed_len=fixed_len
            )

    model = nn.Sequential(
        nn.Linear(dim_in, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, dim_out),
    ).to(device)

    t = gpt.training
    gpt.eval()

    try:
        for n_epoch in range(args.nb_epochs):
            learning_rate = learning_rate_schedule[n_epoch]
            # Re-created every epoch, like in the main training loop, so the
            # scheduled learning rate applies (this also resets Adam's state).
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

            acc_train_loss, nb_train_samples = 0, 0
            for mazes, policies in task.policy_batches(split="train"):
                output = model(gpt_features(mazes))
                loss = compute_loss(mazes, output, policies, task.height, task.width)
                acc_train_loss += loss.item() * mazes.size(0)
                nb_train_samples += mazes.size(0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            acc_test_loss, nb_test_samples = 0, 0
            with torch.no_grad():
                for mazes, policies in task.policy_batches(split="test"):
                    output = model(gpt_features(mazes))
                    loss = compute_loss(
                        mazes, output, policies, task.height, task.width
                    )
                    acc_test_loss += loss.item() * mazes.size(0)
                    nb_test_samples += mazes.size(0)

            log_string(
                f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
            )

            # -------------------
            # Visualize the probe's output on the first 32 test mazes.
            with torch.no_grad():
                mazes = task.test_input[:32, :fixed_len]
                policies = task.test_policies[:32]
                output = model(gpt_features(mazes))
                if args.oneshot_output == "policy":
                    targets = policies.permute(0, 2, 1)
                    scores = (
                        (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1)
                        == 0
                    ).float()
                else:  # "trace"; the value was validated at entry
                    targets = maze.stationary_densities(
                        mazes.view(-1, task.height, task.width),
                        policies.view(-1, 4, task.height, task.width),
                    ).flatten(-2)
                    scores = output

            scores = scores.reshape(-1, task.height, task.width)
            mazes = mazes.reshape(-1, task.height, task.width)
            targets = targets.reshape(-1, task.height, task.width)
            filename = (
                f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png"
            )
            maze.save_image(
                os.path.join(args.result_dir, filename),
                mazes=mazes,
                score_paths=scores,
                score_truth=targets,
            )
            log_string(f"wrote {filename}")
            # -------------------
    finally:
        gpt.train(t)
336
337
338 ######################################################################
339
340
class Task:
    """Base task interface; every method here is a no-op returning None."""

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Subclasses yield the batches of `split`; this base returns None."""

    def vocabulary_size(self):
        """Subclasses return the number of token codes; this base returns None."""

    def produce_results(self, n_epoch, model):
        """Subclasses evaluate/log `model` at epoch `n_epoch`; no-op here."""
350
351
352 ######################################################################
353
354 import maze
355
356
class TaskMaze(Task):
    """Maze shortest-path task.

    Each sample is one token sequence: the flattened maze followed by the
    flattened path map, both height x width (see map2seq). The first
    height * width tokens are given, and the model has to generate the rest.
    Per-sample policies (flattened to N x 4 x (H*W)) are kept alongside for
    the one-shot probe.
    """

    def map2seq(self, *m):
        """Flatten each (N, ...) map to (N, -1) and concatenate along dim 1."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Split (N, k * height * width) sequences back into k maps.

        Returns a generator over the k (N, height, width) maps, meant to be
        tuple-unpacked by the caller.
        """
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, train_policies = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
        self.train_policies = train_policies.flatten(-2).to(device)

        test_mazes, test_paths, test_policies = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
        self.test_policies = test_policies.flatten(-2).to(device)

        # Size the vocabulary on both splits, so a code occurring only in the
        # test set is still inside the vocabulary. Stored as a plain int.
        self.nb_codes = (
            max(self.train_input.max().item(), self.test_input.max().item()) + 1
        )

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of `batch_size` full sequences from `split`.

        Only the first `nb_to_use` samples are used when it is positive.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield (mazes, policies) batches from `split`.

        mazes holds the maze part of each sequence (its first height*width
        tokens). policies is the stored per-cell policy with wall cells
        zeroed.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        policies = self.train_policies if split == "train" else self.test_policies
        input = input[:, : self.height * self.width]

        if nb_to_use > 0:
            input = input[:nb_to_use]
            policies = policies[:nb_to_use]

        # Zero the policy on wall cells. Doing it after truncation gives the
        # same result with less work.
        policies = policies * (input != maze.v_wall)[:, None]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            zip(input.split(self.batch_size), policies.split(self.batch_size)),
            dynamic_ncols=True,
            desc=desc,
        ):
            yield batch

    def vocabulary_size(self):
        """Number of distinct token codes (max code over both splits + 1)."""
        return self.nb_codes

    def _complete_paths(self, model, input):
        """Return a copy of `input` whose path tokens are regenerated by `model`.

        Every token after the first height*width (the maze) is zeroed and then
        generated autoregressively, in the order given by `shuffle`. The result
        is mapped back to the original token layout.
        """
        fixed_len = self.height * self.width
        result = input.clone()
        ar_mask = result.new_zeros(result.size())
        ar_mask[:, fixed_len:] = 1
        result *= 1 - ar_mask
        # ar_mask needs no reordering: the generation order keeps the first
        # fixed_len positions in place and only permutes the all-ones tail.
        x, order = shuffle(result, fixed_len)
        masked_inplace_autoregression(model, self.batch_size, x, ar_mask, order=order)
        return reorder(x, order, reverse=True)

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Count the mazes of `split` for which the model generates a correct path.

        Returns (nb_total, nb_correct) as Python ints, with correctness judged
        by maze.path_correctness. Only the first `nb_to_use` samples are used
        when it is positive.
        """
        nb_total, nb_correct = 0, 0
        for batch in self.batches(split, nb_to_use):
            result = self._complete_paths(model, batch)
            mazes, paths = self.seq2map(result)
            nb_correct += maze.path_correctness(mazes, paths).long().sum().item()
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log path accuracy and save a visualization for epoch `n_epoch`.

        Accuracy is computed on the first 1000 samples of each split. The
        image shows the first 32 test mazes and is written as
        result_{n_epoch:04d}.png in args.result_dir. Runs under no_grad with
        the model in eval mode; the previous training flag is restored
        afterwards, even on error.
        """
        t = model.training
        model.eval()
        try:
            with torch.no_grad():
                for split in ("train", "test"):
                    nb_total, nb_correct = self.compute_error(
                        model, split, nb_to_use=1000
                    )
                    log_string(
                        f"accuracy_{split} nb_total {nb_total} nb_correct {nb_correct} accuracy {(100.0*nb_correct)/nb_total:.02f}%"
                    )

                samples = self.test_input[:32]
                result = self._complete_paths(model, samples)

                mazes, paths = self.seq2map(samples)
                _, predicted_paths = self.seq2map(result)
                filename = f"result_{n_epoch:04d}.png"
                maze.save_image(
                    os.path.join(args.result_dir, filename),
                    mazes=mazes,
                    target_paths=paths,
                    predicted_paths=predicted_paths,
                    path_correct=maze.path_correctness(mazes, predicted_paths),
                )
                log_string(f"wrote {filename}")
        finally:
            model.train(t)
498
499
500 ######################################################################
501
log_string(f"device {device}")

# Generate the train/test maze datasets (they are stored on `device`).
task = TaskMaze(
    height=args.maze_height,
    width=args.maze_width,
    nb_walls=args.maze_nb_walls,
    nb_train_samples=args.nb_train_samples,
    nb_test_samples=args.nb_test_samples,
    batch_size=args.batch_size,
    device=device,
)

vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")
519
520 ##############################
521
amm_generator = None

if args.noncausal_prompt:

    def amm_generator(d):
        """Return a (1, 1, d, d) boolean tensor that is True at (i, j) iff
        i < j and i >= d // 2.

        NOTE(review): presumably consumed by mygpt.MyGPT to adjust its
        attention masking for a non-causal prompt — confirm in mygpt.
        """
        rows = torch.arange(d)[None, None, :, None]
        cols = torch.arange(d)[None, None, None, :]
        return torch.logical_and(rows < cols, rows >= d // 2)
529
# Causal GPT over the task's token vocabulary, moved onto the compute device
# (nn.Module.to returns the module itself).
model = mygpt.MyGPT(
    vocabulary_size=vocabulary_size,
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    causal=True,
    dropout=args.dropout,
    amm_generator=amm_generator,
).to(device)

nb_parameters = sum(map(torch.numel, model.parameters()))
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
546
547 ######################################################################
548
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f"not trying to load checkpoint.")

else:
    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    try:
        # Load onto the CPU: the RNG states must be CPU ByteTensors, and
        # load_state_dict copies the weights onto the model's device anyway.
        # This also allows resuming a GPU run on a CPU-only machine.
        checkpoint = torch.load(checkpoint_name, map_location="cpu")
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        # A checkpoint written without CUDA carries no CUDA RNG state.
        if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception as e:
        # Any other failure (corrupt file, architecture mismatch, ...) is
        # fatal: training from scratch would overwrite the checkpoint.
        log_string(f"error when loading the checkpoint: {e!r}")
        sys.exit(1)
572
573 ######################################################################
574
# Perplexity of the empirical unigram token distribution of the training set,
# logged next to the model's own perplexities as a baseline.
token_count = sum(
    F.one_hot(batch, num_classes=task.vocabulary_size()).sum((0, 1))
    for batch in task.batches(split="train")
)
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
581
582 ##############################
583
def _build_learning_rate_schedule(spec, base_lr, nb_epochs):
    """Map each epoch in range(nb_epochs) to its learning rate.

    spec is either "cos" (cosine decay from base_lr over nb_epochs), or a
    comma-separated list of "epoch: lr" pairs. In the latter case the rate
    starts at base_lr and switches to lr from the listed epoch on.
    """
    if spec == "cos":
        return {
            n: base_lr * 0.5 * (1 + math.cos(n / nb_epochs * math.pi))
            for n in range(nb_epochs)
        }

    pairs = (item.split(":") for item in spec.split(","))
    changes = {int(epoch): float(lr) for epoch, lr in pairs}

    schedule, lr = {}, base_lr
    for n in range(nb_epochs):
        lr = changes.get(n, lr)
        schedule[n] = lr
    return schedule


learning_rate_schedule = _build_learning_rate_schedule(
    args.learning_rate_schedule, args.learning_rate, args.nb_epochs
)

log_string(f"learning_rate_schedule {learning_rate_schedule}")
605
606 ##############################
607
# Nothing left to train (e.g. a finished checkpoint was loaded): evaluate the
# model once more and regenerate its results.
if nb_epochs_finished >= args.nb_epochs:
    n_epoch = nb_epochs_finished
    fixed_len = task.height * task.width
    train_perplexity, test_perplexity = (
        compute_perplexity(model, task, fixed_len=fixed_len, split=split)
        for split in ("train", "test")
    )

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)
622
623 ##############################
624
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # The optimizer is re-created every epoch so it picks up the scheduled
    # learning rate; this also resets its internal state.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"{args.optim=}")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for batch in task.batches(split="train"):
        batch = batch.to(device)
        # Next-token training needs the vocabulary logits, i.e. the default
        # "standard" mode, exactly as compute_perplexity evaluates it. The
        # --oneshot_input modes ("head"/"deep") yield hidden features of
        # width dim_model or dim_model*nb_blocks*2 (see oneshot), not logits.
        output = eval_mygpt(model, batch, fixed_len=task.height * task.width)
        loss = F.cross_entropy(output.transpose(1, 2), batch)
        acc_train_loss += loss.item() * batch.size(0)
        nb_train_samples += batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
    test_perplexity = compute_perplexity(
        model, task, fixed_len=task.height * task.width, split="test"
    )

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    # Write to a temporary file and atomically rename it, so an interruption
    # during torch.save cannot leave a truncated checkpoint for the next run.
    tmp_checkpoint_name = checkpoint_name + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_name)
    os.replace(tmp_checkpoint_name, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")
679
680 ######################################################################
681
# Optionally fit the one-shot probe on top of the trained GPT.
if args.oneshot:
    oneshot(gpt=model, task=task)
684
685 ######################################################################