6ec0fb290e2109077b6aefe1a2ae63d032e755b2
[beaver.git] / beaver.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
# torch.backends.cuda.matmul.allow_tf32
9 # torch.autocast(torch.bfloat16)
10
11 import math, sys, argparse, time, tqdm, itertools, os
12
13 import torch, torchvision
14 from torch import nn
15 from torch.nn import functional as F
16
17 import mygpt, tensorstack
18
19 ######################################################################
20
# Select the compute device once. On CUDA, also let matmuls use TF32, which
# trades a little float32 precision for speed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
26
27 ######################################################################
28
parser = argparse.ArgumentParser(description="A maze shortest path solving with a GPT.")

parser.add_argument("--log_filename", type=str, default="train.log")

parser.add_argument(
    "--result_dir",
    type=str,
    default="results_default",
    help="directory for the log, checkpoint and images",
)

parser.add_argument(
    "--seed", type=int, default=0, help="torch seed; a negative value disables seeding"
)

parser.add_argument("--nb_epochs", type=int, default=25)

parser.add_argument("--nb_train_samples", type=int, default=200000)

parser.add_argument("--nb_test_samples", type=int, default=50000)

parser.add_argument("--batch_size", type=int, default=25)

# Validated here (rather than after data generation) so a typo fails fast.
parser.add_argument(
    "--optim", type=str, default="adam", choices=["sgd", "adam", "adamw"]
)

parser.add_argument("--learning_rate", type=float, default=1e-3)

parser.add_argument(
    "--learning_rate_schedule",
    type=str,
    default="10: 2e-4,20: 4e-5,30: 8e-6",
    help="'cos' for a cosine decay from --learning_rate, or comma-separated "
    "'epoch: lr' pairs, each applying from its epoch on "
    "(--learning_rate before the first one)",
)

parser.add_argument("--dim_model", type=int, default=512)

parser.add_argument("--dim_keys", type=int, default=64)

parser.add_argument("--dim_hidden", type=int, default=2048)

parser.add_argument("--nb_heads", type=int, default=8)

parser.add_argument("--nb_blocks", type=int, default=12)

parser.add_argument("--dropout", type=float, default=0.1)

parser.add_argument(
    "--deterministic_synthesis",
    action="store_true",
    default=False,
    help="generate with argmax instead of sampling",
)

parser.add_argument(
    "--no_checkpoint",
    action="store_true",
    default=False,
    help="do not resume from an existing checkpoint (one is still saved each epoch)",
)

parser.add_argument(
    "--overwrite_results",
    action="store_true",
    default=False,
    help="allow reusing an existing --result_dir",
)

parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")

##############################
# maze options

parser.add_argument("--maze_height", type=int, default=13)

parser.add_argument("--maze_width", type=int, default=21)

parser.add_argument("--maze_nb_walls", type=int, default=15)

##############################
# one-shot prediction

parser.add_argument(
    "--oneshot",
    action="store_true",
    default=False,
    help="after training, fit a small MLP probe on the GPT's activations",
)

# Validated here: otherwise a typo is only detected after the whole training.
parser.add_argument(
    "--oneshot_input", type=str, default="head", choices=["head", "deep"]
)

parser.add_argument(
    "--oneshot_output", type=str, default="trace", choices=["policy", "trace"]
)

######################################################################

args = parser.parse_args()
94
# Create the result directory (and any missing parents). If it already
# exists, refuse to continue unless --overwrite_results is given.
try:
    os.makedirs(args.result_dir)
except FileExistsError:
    if not args.overwrite_results:
        print(f"result directory {args.result_dir} already exists")
        sys.exit(1)

# Append mode: logs of earlier runs in the same directory are kept.
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")

if args.seed >= 0:
    # Stricter determinism, currently disabled:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
111
112 ######################################################################
113
114
def log_string(s):
    """Print `s` behind a local timestamp and, when a log file is open,
    append the same line to it (flushing both outputs immediately)."""
    line = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime()) + s

    if log_file is not None:
        log_file.write(line + "\n")
        log_file.flush()

    print(line)
    sys.stdout.flush()
124
125
# Record every command-line option in the log.
for name, value in vars(args).items():
    log_string(f"args.{name} {value}")
128
129 ######################################################################
130
131
def random_order(result, fixed_len):
    """Return, for each row of `result`, a permutation of its column indices
    that keeps the first `fixed_len` columns in place and shuffles the rest.

    A random key in [0, 1) is drawn per position. The first `fixed_len` keys
    are overwritten with increasing values in [-2, -1], so they sort first
    and in order. Each row is then argsorted.
    """
    keys = torch.rand(result.size(), device=result.device)
    keys[:, :fixed_len] = torch.linspace(-2, -1, fixed_len, device=keys.device)
    _, permutation = keys.sort(1)
    return permutation
136
137
def shuffle(x, order, reorder=False):
    """Permute `x` along dim 1 with the per-row permutation `order`.

    reorder=False applies it:  out[:, j] = x[:, order[:, j]].
    reorder=True undoes it:    out[:, order[:, j]] = x[:, j].
    For a 3-D `x`, the index is broadcast over the last dimension.
    """
    if x.dim() == 3:
        index = order.unsqueeze(-1).expand(-1, -1, x.size(-1))
    else:
        index = order

    if not reorder:
        return x.gather(1, index)

    out = x.new_empty(x.size())
    return out.scatter_(1, index, x)
147
148
# ar_mask is a 0/1 matrix of the same shape as input, with 1s on the tokens
# that should be generated; the other tokens are left unchanged.


def masked_inplace_autoregression(model, batch_size, input, ar_mask, order=None):
    """Generate, in place, the tokens of `input` where `ar_mask` is 1.

    Works on chunks of `batch_size` rows. Within a chunk, every column from
    the first to the last one that has a 1 in some row is produced one at a
    time by `model`. The token is the argmax with --deterministic_synthesis;
    otherwise it is sampled. Only entries with ar_mask == 1 are overwritten.

    `order`, if given, must have one row per row of `input`. It is split
    into the same chunks so each chunk is decoded with its own rows' order.
    Chunks with nothing to generate are skipped.
    """
    input_chunks = input.split(batch_size)
    mask_chunks = ar_mask.split(batch_size)
    order_chunks = (
        itertools.repeat(None) if order is None else order.split(batch_size)
    )

    for x, mask, chunk_order in zip(input_chunks, mask_chunks, order_chunks):
        # Columns in which at least one row has a token to generate.
        i = (mask.sum(0) > 0).nonzero()
        if i.numel() == 0:
            continue
        first, last = i.min().item(), i.max().item()
        if first > 0:
            # Needed to initialize the model's cache
            model(mygpt.BracketedSequence(x, 0, first), order=chunk_order)
        for s in range(first, last + 1):
            output = model(mygpt.BracketedSequence(x, s, 1), order=chunk_order).x
            logits = output[:, s]
            if args.deterministic_synthesis:
                t_next = logits.argmax(1)
            else:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                t_next = dist.sample()
            # `x` is a view of `input`, so this writes into the caller's tensor.
            x[:, s] = mask[:, s] * t_next + (1 - mask[:, s]) * x[:, s]
168
169
170 ######################################################################
171
172
def compute_perplexity(model, split="train"):
    """Return exp(average cross-entropy) of `model` on the `split` batches.

    Each batch is permuted with a random order that keeps the first
    task.height * task.width tokens in place, as in the training loop.
    Batch losses are averaged with weights equal to batch sizes, and the
    exponent is capped at 100. The model's training flag is restored.
    """
    with torch.autograd.no_grad():
        was_training = model.training
        model.eval()

        total_loss, total_samples = 0.0, 0
        prefix_len = task.height * task.width

        for batch in task.batches(split=split):
            batch = batch.to(device)
            order = random_order(batch, prefix_len)
            batch = shuffle(batch, order)
            logits = model(mygpt.BracketedSequence(batch), order=order).x
            batch_loss = F.cross_entropy(logits.transpose(1, 2), batch)
            total_loss += batch_loss.item() * batch.size(0)
            total_samples += batch.size(0)

        model.train(was_training)

        return math.exp(min(100, total_loss / total_samples))
192
193
194 ######################################################################
195
196
def oneshot_policy_loss(mazes, output, policies, height, width):
    """Cross-entropy between predicted logits and target policies, restricted
    to cells equal to maze.v_empty and normalized by their count.

    `policies` is permuted from (N, 4, cells) to (N, cells, 4) to match
    `output`. `height` and `width` are unused here.
    """
    empty = (mazes == maze.v_empty).unsqueeze(-1)
    masked_targets = policies.permute(0, 2, 1) * empty
    masked_logits = output * empty
    log_probs = masked_logits.log_softmax(-1)
    return -(log_probs * masked_targets).sum() / empty.sum()
202
203
def oneshot_trace_loss(mazes, output, policies, height, width):
    """Mean absolute error, over cells equal to maze.v_empty, between the
    predicted trace and maze.stationary_densities of `policies`."""
    empty = mazes == maze.v_empty
    densities = maze.stationary_densities(
        mazes.view(-1, height, width), policies.view(-1, 4, height, width)
    )
    targets = densities.flatten(-2) * empty
    predictions = output.squeeze(-1) * empty
    return (predictions - targets).abs().sum() / empty.sum()
212
213
def _oneshot_features(gpt, mazes, height, width):
    """Return the frozen `gpt`'s activations for every cell of `mazes`.

    `mazes` holds only the height * width maze tokens, so random_order keeps
    every position fixed. The shuffle / un-shuffle only mirrors the training
    interface of `gpt`. No gradient is tracked: only the probe is trained,
    and the GPT's .grad buffers must not accumulate.
    """
    with torch.no_grad():
        order = random_order(mazes, height * width)
        x = shuffle(mazes, order)
        x = gpt(mygpt.BracketedSequence(x), mode=args.oneshot_input, order=order).x
        return shuffle(x, order, reorder=True)


def oneshot(gpt, task):
    """Train a small MLP probe that maps the GPT's per-cell activations on a
    maze to either the policy (--oneshot_output policy) or the stationary
    densities (--oneshot_output trace).

    The activations come from `gpt` called with mode=--oneshot_input ("head":
    dim_model features, "deep": dim_model * nb_blocks * 2 features). Each
    epoch logs the train/test loss and saves an image of the probe's output
    on the first 32 test mazes. The GPT's training flag is restored at the end.

    Raises ValueError for an unknown --oneshot_input / --oneshot_output.
    """
    t = gpt.training
    gpt.eval()

    if args.oneshot_input == "head":
        dim_in = args.dim_model
    elif args.oneshot_input == "deep":
        dim_in = args.dim_model * args.nb_blocks * 2
    else:
        raise ValueError(f"{args.oneshot_input=}")

    if args.oneshot_output == "policy":
        dim_out = 4
        compute_loss = oneshot_policy_loss
    elif args.oneshot_output == "trace":
        dim_out = 1
        compute_loss = oneshot_trace_loss
    else:
        raise ValueError(f"{args.oneshot_output=}")

    # The probe: the only parameters optimized here.
    model = nn.Sequential(
        nn.Linear(dim_in, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, args.dim_model),
        nn.ReLU(),
        nn.Linear(args.dim_model, dim_out),
    ).to(device)

    height, width = task.height, task.width

    for n_epoch in range(args.nb_epochs):
        learning_rate = learning_rate_schedule[n_epoch]
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        acc_train_loss, nb_train_samples = 0, 0
        for mazes, policies in task.policy_batches(split="train"):
            output = model(_oneshot_features(gpt, mazes, height, width))
            loss = compute_loss(mazes, output, policies, height, width)
            acc_train_loss += loss.item() * mazes.size(0)
            nb_train_samples += mazes.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc_test_loss, nb_test_samples = 0, 0
        with torch.no_grad():
            for mazes, policies in task.policy_batches(split="test"):
                output = model(_oneshot_features(gpt, mazes, height, width))
                loss = compute_loss(mazes, output, policies, height, width)
                acc_test_loss += loss.item() * mazes.size(0)
                nb_test_samples += mazes.size(0)

        log_string(
            f"diff_ce {n_epoch} train {acc_train_loss/nb_train_samples} test {acc_test_loss/nb_test_samples}"
        )

        # Picture of the probe's predictions on the first 32 test mazes.
        with torch.no_grad():
            mazes = task.test_input[:32, : height * width]
            policies = task.test_policies[:32]
            output = model(_oneshot_features(gpt, mazes, height, width))
            if args.oneshot_output == "policy":
                targets = policies.permute(0, 2, 1)
                # 1.0 where the predicted (argmax) direction has zero weight
                # in the target policy.
                scores = (
                    (F.one_hot(output.argmax(-1), num_classes=4) * targets).sum(-1)
                    == 0
                ).float()
            elif args.oneshot_output == "trace":
                targets = maze.stationary_densities(
                    mazes.view(-1, height, width),
                    policies.view(-1, 4, height, width),
                ).flatten(-2)
                scores = output
            else:
                raise ValueError(f"{args.oneshot_output=}")

            maze.save_image(
                os.path.join(
                    args.result_dir,
                    f"oneshot_{args.oneshot_input}_{args.oneshot_output}_{n_epoch:04d}.png",
                ),
                mazes=mazes.reshape(-1, height, width),
                score_paths=scores.reshape(-1, height, width),
                score_truth=targets.reshape(-1, height, width),
            )

    gpt.train(t)
314
315
316 ######################################################################
317
318
class Task:
    """Interface implemented by the concrete tasks. Every method here is a
    placeholder that does nothing and returns None; subclasses override them."""

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Subclasses yield the batches of token sequences of `split`."""

    def vocabulary_size(self):
        """Subclasses return the number of distinct tokens."""

    def produce_results(self, n_epoch, model):
        """Subclasses evaluate `model` after epoch `n_epoch` and report."""
328
329
330 ######################################################################
331
332 import maze
333
334
class TaskMaze(Task):
    """Maze shortest-path task.

    Each sequence is a flattened maze (height * width tokens) followed by a
    flattened path map (height * width tokens); see map2seq / seq2map. The
    model has to generate the path part given the maze.
    """

    def map2seq(self, *m):
        """Flatten each map from dim 1 on and concatenate them along dim 1."""
        return torch.cat([x.flatten(1) for x in m], 1)

    def seq2map(self, s):
        """Inverse of map2seq: split sequences into (N, height, width) maps.

        Returns a generator, meant to be unpacked directly, e.g.
        `mazes, paths = self.seq2map(seq)`.
        """
        s = s.reshape(s.size(0), -1, self.height, self.width)
        return (s[:, k] for k in range(s.size(1)))

    def __init__(
        self,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        height,
        width,
        nb_walls,
        device=torch.device("cpu"),
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.device = device

        train_mazes, train_paths, train_policies = maze.create_maze_data(
            nb_train_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
        )
        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
        # Flatten the two spatial dims so cell indices match the flat maze.
        self.train_policies = train_policies.flatten(-2).to(device)

        test_mazes, test_paths, test_policies = maze.create_maze_data(
            nb_test_samples,
            height=height,
            width=width,
            nb_walls=nb_walls,
            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
        )
        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
        self.test_policies = test_policies.flatten(-2).to(device)

        # Vocabulary = largest training token + 1. Kept as a Python int rather
        # than a 0-dim device tensor, so it logs cleanly and works as a size.
        self.nb_codes = int(self.train_input.max()) + 1

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield batches of `batch_size` sequences of `split` ("train" or
        "test"). Only the first `nb_to_use` sequences are used when it is
        positive. Shows a tqdm bar labelled `desc` (default "epoch-<split>")."""
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def policy_batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield (mazes, policies) batches of `split`.

        `mazes` holds only the maze part of each sequence (height * width
        tokens). `policies` is zeroed on cells equal to maze.v_wall.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        policies = self.train_policies if split == "train" else self.test_policies
        input = input[:, : self.height * self.width]
        policies = policies * (input != maze.v_wall)[:, None]

        if nb_to_use > 0:
            input = input[:nb_to_use]
            policies = policies[:nb_to_use]

        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            zip(input.split(self.batch_size), policies.split(self.batch_size)),
            dynamic_ncols=True,
            desc=desc,
        ):
            yield batch

    def vocabulary_size(self):
        return self.nb_codes

    def compute_error(self, model, split="train", nb_to_use=-1):
        """Let `model` regenerate the path part of (up to `nb_to_use`)
        sequences of `split`, and count the results that
        maze.path_correctness accepts. Returns (nb_total, nb_correct) as ints.
        """
        nb_total, nb_correct = 0, 0
        maze_len = self.height * self.width
        for input in self.batches(split, nb_to_use):
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, maze_len:] = 1
            # Erase the path tokens: they are what the model has to generate.
            result *= 1 - ar_mask
            order = random_order(result, maze_len)
            masked_inplace_autoregression(
                model, self.batch_size, result, ar_mask, order=order
            )
            # Back to natural token order before splitting into maps.
            result = shuffle(result, order, reorder=True)
            mazes, paths = self.seq2map(result)
            nb_correct += int(maze.path_correctness(mazes, paths).long().sum())
            nb_total += mazes.size(0)

        return nb_total, nb_correct

    def produce_results(self, n_epoch, model):
        """Log train/test path accuracy on 1000 sequences each, and save an
        image of the model's paths for the first 32 test mazes."""
        with torch.autograd.no_grad():
            t = model.training
            model.eval()

            train_nb_total, train_nb_correct = self.compute_error(
                model, "train", nb_to_use=1000
            )
            log_string(
                f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
            )

            test_nb_total, test_nb_correct = self.compute_error(
                model, "test", nb_to_use=1000
            )
            log_string(
                f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
            )

            input = self.test_input[:32]
            result = input.clone()
            ar_mask = result.new_zeros(result.size())
            ar_mask[:, self.height * self.width :] = 1
            result *= 1 - ar_mask
            # NOTE(review): unlike compute_error, no `order` is passed here, so
            # the model gets order=None -- confirm this is intended.
            masked_inplace_autoregression(model, self.batch_size, result, ar_mask)

            mazes, paths = self.seq2map(input)
            _, predicted_paths = self.seq2map(result)
            maze.save_image(
                os.path.join(args.result_dir, f"result_{n_epoch:04d}.png"),
                mazes=mazes,
                target_paths=paths,
                predicted_paths=predicted_paths,
                path_correct=maze.path_correctness(mazes, predicted_paths),
            )

            model.train(t)
470
471
472 ######################################################################
473
log_string(f"device {device}")

# Both datasets are generated here, up front, and moved to `device`.
task = TaskMaze(
    nb_train_samples=args.nb_train_samples,
    nb_test_samples=args.nb_test_samples,
    batch_size=args.batch_size,
    height=args.maze_height,
    width=args.maze_width,
    nb_walls=args.maze_nb_walls,
    device=device,
)

vocabulary_size = task.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")

##############################

# Causal GPT over the maze+path token sequences.
gpt_hyperparameters = dict(
    dim_model=args.dim_model,
    dim_keys=args.dim_keys,
    dim_hidden=args.dim_hidden,
    nb_heads=args.nb_heads,
    nb_blocks=args.nb_blocks,
    dropout=args.dropout,
)
model = mygpt.MyGPT(vocabulary_size=vocabulary_size, causal=True, **gpt_hyperparameters)

model.to(device)

nb_parameters = sum(param.numel() for param in model.parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
509
510 ######################################################################
511
nb_epochs_finished = 0

if args.no_checkpoint:
    log_string(f"not trying to load checkpoint.")

else:
    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    try:
        # Load onto the CPU so a checkpoint written on a GPU machine can also
        # be resumed without CUDA; load_state_dict copies the weights onto the
        # model's own device.
        checkpoint = torch.load(checkpoint_name, map_location="cpu")
        nb_epochs_finished = checkpoint["nb_epochs_finished"]
        model.load_state_dict(checkpoint["model_state"])
        torch.set_rng_state(checkpoint["rng_state"])
        # A checkpoint saved without CUDA has no CUDA RNG state to restore.
        if torch.cuda.is_available() and "cuda_rng_state" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])

        log_string(f"checkpoint loaded with {nb_epochs_finished} epochs finished.")

    except FileNotFoundError:
        log_string("starting from scratch.")

    except Exception as e:
        # Report why; a bare except would also trap KeyboardInterrupt.
        log_string(f"error when loading the checkpoint: {e!r}")
        sys.exit(1)
535
536 ######################################################################
537
# Unigram token distribution of the training set and its perplexity exp(H),
# logged each epoch as the "train_set" reference value.
nb_tokens = task.vocabulary_size()
token_count = 0
for input in task.batches(split="train"):
    token_count = token_count + F.one_hot(input, num_classes=nb_tokens).sum((0, 1))
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
544
545 ##############################
546
if args.learning_rate_schedule == "cos":
    # Cosine decay from args.learning_rate towards 0 over the epochs.
    learning_rate_schedule = {
        n_epoch: args.learning_rate
        * 0.5
        * (1 + math.cos(n_epoch / args.nb_epochs * math.pi))
        for n_epoch in range(args.nb_epochs)
    }
else:
    # "epoch: lr" pairs. Each rate applies from its epoch until the next
    # listed one; args.learning_rate is used before the first.
    changes = {}
    for item in args.learning_rate_schedule.split(","):
        k, v = item.split(":")
        changes[int(k)] = float(v)

    learning_rate_schedule = {}
    learning_rate = args.learning_rate
    for n_epoch in range(args.nb_epochs):
        learning_rate = changes.get(n_epoch, learning_rate)
        learning_rate_schedule[n_epoch] = learning_rate

log_string(f"learning_rate_schedule {learning_rate_schedule}")
568
569 ##############################
570
if nb_epochs_finished >= args.nb_epochs:
    # Nothing left to train (e.g. resumed from a finished checkpoint): only
    # evaluate, run the optional one-shot probe, and stop.
    n_epoch = nb_epochs_finished
    train_perplexity = compute_perplexity(model, split="train")
    test_perplexity = compute_perplexity(model, split="test")

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)

    # The script stops here, so without this --oneshot would be silently
    # ignored when resuming a finished run.
    if args.oneshot:
        oneshot(model, task)

    sys.exit(0)
583
584 ##############################
585
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
    learning_rate = learning_rate_schedule[n_epoch]

    log_string(f"learning_rate {learning_rate}")

    # NOTE(review): the optimizer is rebuilt every epoch, so Adam/AdamW
    # moment estimates restart each epoch (presumably intended) -- confirm.
    if args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif args.optim == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"{args.optim=}")

    model.train()

    nb_train_samples, acc_train_loss = 0, 0.0

    for input in task.batches(split="train"):
        input = input.to(device)
        # Random token order that keeps the maze prefix in place.
        order = random_order(input, task.height * task.width)
        input = shuffle(input, order)
        output = model(mygpt.BracketedSequence(input), order=order).x
        loss = F.cross_entropy(output.transpose(1, 2), input)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
    test_perplexity = compute_perplexity(model, split="test")

    log_string(
        f"perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}"
    )

    task.produce_results(n_epoch, model)

    checkpoint = {
        "nb_epochs_finished": n_epoch + 1,
        "model_state": model.state_dict(),
        "rng_state": torch.get_rng_state(),
    }

    if torch.cuda.is_available():
        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()

    checkpoint_name = os.path.join(args.result_dir, args.checkpoint_name)
    # Write to a temporary file, then atomically rename it over the previous
    # checkpoint. An interruption during torch.save then cannot leave a
    # truncated checkpoint that the next run would fail to load.
    tmp_checkpoint_name = checkpoint_name + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_name)
    os.replace(tmp_checkpoint_name, checkpoint_name)
    log_string(f"saved checkpoint {checkpoint_name}")

######################################################################

if args.oneshot:
    oneshot(model, task)
643
644 ######################################################################