5 import torch, torchvision
8 from torch.nn import functional as F
10 ######################################################################
13 def masked_inplace_autoregression(
18 deterministic_synthesis,
19 forbidden_tokens=None,
20 progress_bar_desc="autoregression",
21 device=torch.device("cpu"),
23 assert input.size() == ar_mask.size()
25 batches = zip(input.split(batch_size), ar_mask.split(batch_size))
27 if progress_bar_desc is not None:
31 desc=progress_bar_desc,
32 # total=input.size(0) // batch_size,
35 with torch.autograd.no_grad():
39 for input, ar_mask in batches:
40 model.masked_inplace_autoregression(
41 input, ar_mask, forbidden_tokens, deterministic_synthesis
47 ######################################################################
51 def batches(self, split="train"):
54 def vocabulary_size(self):
58 self, n_epoch, model, result_dir, logger, deterministic_synthesis
63 ######################################################################
67 def generate_sequences(self, nb):
70 def log_performance(self, sequences, logger):
74 class ProblemByheart(Problem):
76 nb_seq, len_prompt, len_result = 100, 5, 5
77 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
78 self.seq[:, len_prompt] = 10
80 def generate_sequences(self, nb):
81 sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
82 ar_mask = (sequences == 10).long()
83 ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
84 return sequences, ar_mask
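# Worked example of the mask construction above: with a prompt of length 2, a
# result of length 2 and the separator token 10, the row [3, 7, 10, 4, 1] gives
# (sequences == 10) = [0, 0, 1, 0, 0]; the cumsum minus the mask yields
# [0, 0, 0, 1, 1], i.e. ar_mask is 1 strictly after the separator, so only the
# result part is generated autoregressively.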
86 # problems = [ProblemByheart()]
87 # nb_common_codes = 100
89 # def generate_sequences(nb_samples):
90 # problem_indexes = torch.randint(len(problems), (nb_samples,))
91 # nb_samples_per_problem = F.one_hot(problem_indexes).sum(0)
92 # print(f"{nb_samples_per_problem}")
94 # for nb, p in zip(nb_samples_per_problem, problems):
95 # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
98 # for strain, stest in zip(train_seq, test_seq):
99 # s = torch.cat((strain, stest), 0)
110 device=torch.device("cpu"),
115 self.batch_size = batch_size
118 self.train_input, self.train_ar_mask = problem.generate_sequences(
121 self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
123 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
125 # A bit of paranoia never hurts
127 self.nb_codes <= max_nb_codes
128 and self.train_input.min() >= 0
129 and self.test_input.min() >= 0
130 and tuple(self.train_ar_mask.unique()) == (0, 1)
131 and tuple(self.test_ar_mask.unique()) == (0, 1)
134 def batches(self, split="train", nb_to_use=-1, desc=None):
135 assert split in {"train", "test"}
136 input = self.train_input if split == "train" else self.test_input
138 input = input[:nb_to_use]
140 desc = f"epoch-{split}"
141 for batch in tqdm.tqdm(
142 input.split(self.batch_size), dynamic_ncols=True, desc=desc
146 def vocabulary_size(self):
150 self, n_epoch, model, result_dir, logger, deterministic_synthesis
152 def compute_accuracy(input, ar_mask):
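# Blank the positions flagged by ar_mask, let the model regenerate them, and
# count how many of the masked tokens are reproduced exactly.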
153 result = input.clone() * (1 - ar_mask)
154 masked_inplace_autoregression(
159 deterministic_synthesis,
160 progress_bar_desc=None,
164 nb_total = ar_mask.sum().item()
165 nb_correct = ((result == input).long() * ar_mask).sum().item()
167 return nb_total, nb_correct
169 train_nb_total, train_nb_correct = compute_accuracy(
170 self.train_input, self.train_ar_mask
174 f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
177 test_nb_total, test_nb_correct = compute_accuracy(
178 self.test_input, self.test_ar_mask
182 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
186 ######################################################################
191 class PicoCLVR(Task):
192 # Make a tensor from a list of strings
193 def tensorize(self, descr):
194 token_descr = [s.strip().split(" ") for s in descr]
195 len_max = max([len(s) for s in token_descr])
196 token_descr = [s + ["<nul>"] * (len_max - len(s)) for s in token_descr]
197 id_descr = [[self.token2id[u] for u in s] for s in token_descr]
198 return torch.tensor(id_descr, device=self.device)
200 # Make a list of strings from a tensor
201 def detensorize(self, x):
202 return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
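# For instance (illustrative call, not from the original file),
# tensorize(["red above green <img>"]) splits on spaces, pads every
# description with "<nul>" up to the longest one, and maps tokens to ids with
# self.token2id; detensorize() is the inverse, up to that padding.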
204 # trim all the tensors in the tuple z to remove as many tokens as possible
205 # from the left and right of the first tensor. If z is a tuple, all its
206 # elements are trimmed according to the trimming computed for the first one.
207 def trim(self, z, token="<nul>"):
208 n = self.token2id[token]
211 i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
212 a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
213 return tuple([t[:, a:b] for t in z])
215 i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
216 a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
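# Thanks to the <nul>-padding on both sides, i increases by one exactly at the
# columns where at least one row holds a non-<nul> token. a is the index of
# the last all-<nul> column before any content and b that of the last content
# column (both in padded coordinates), so on the un-padded tensors the slice
# [:, a:b] spans exactly the first to the last non-empty column.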
219 ######################
230 device=torch.device("cpu"),
236 def generate_descr(nb, cache_suffix, pruner):
237 return picoclvr.generate(
247 self.batch_size = batch_size
249 self.pruner_train = pruner_train
250 self.pruner_eval = pruner_eval
252 if logger is not None:
254 f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
257 self.train_descr = generate_descr(
258 nb_train_samples, "train", pruner=self.pruner_train
260 self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
262 # Build the tokenizer
263 tokens = {"<nul>", "<img>"}
264 for d in [self.train_descr, self.test_descr]:
266 for t in s.strip().split(" "):
268 # make this set a sorted list to get the same tensors given the same descriptions
270 tokens = list(tokens)
272 self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
273 self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
274 self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
276 # Tokenize the train and test sets
277 self.train_input = self.tensorize(self.train_descr)
278 self.test_input = self.tensorize(self.test_descr)
280 def batches(self, split="train"):
281 assert split in {"train", "test"}
282 input = self.train_input if split == "train" else self.test_input
283 for batch in tqdm.tqdm(
284 input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
286 yield self.trim(batch)
288 def vocabulary_size(self):
289 return len(self.token2id)
291 def compute_missing_properties(
292 self, n_epoch, model, logger, deterministic_synthesis, pruner=None
294 acc_nb_requested_properties = []
295 acc_nb_missing_properties = []
298 for input in tqdm.tqdm(
299 self.test_input.split(self.batch_size),
301 desc="test-properties",
303 result = input.clone()
304 ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
305 result = (1 - ar_mask) * result + ar_mask * self.t_nul
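# Everything from the <img> token onward (that token included) is masked and
# replaced by <nul>, so the model regenerates the image part of the
# description while the textual properties are kept as the prompt.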
306 masked_inplace_autoregression(
311 deterministic_synthesis,
312 progress_bar_desc=None,
316 result_descr = self.detensorize(result)
317 np = picoclvr.nb_properties(
323 nb_requested_properties, _, nb_missing_properties = zip(*np)
324 acc_nb_requested_properties += nb_requested_properties
325 acc_nb_missing_properties += nb_missing_properties
326 acc_nb_results += len(result_descr)
328 nb_requested_properties = sum(acc_nb_requested_properties)
329 nb_missing_properties = sum(acc_nb_missing_properties)
331 prefix = "" if pruner is None else "pruned_"
332 logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
334 f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
337 f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
340 ######################################################################
343 self, n_epoch, model, result_dir, logger, deterministic_synthesis
345 self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
347 if self.pruner_eval is not None:
348 self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis, pruner=self.pruner_eval)
350 nb_tokens_to_generate = self.height * self.width + 3
355 for primer_descr in [
356 "red above green <sep> green top <sep> blue right of red",
357 "there is red <sep> there is yellow <sep> there is blue",
358 "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
359 "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
361 primer += [primer_descr + " <img>"] * nb_per_primer
363 result = self.tensorize(primer)
364 fill = result.new_full(
365 result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
367 result = torch.cat((result, fill), 1)
368 ar_mask = (result == self.t_nul).long()
369 masked_inplace_autoregression(
374 deterministic_synthesis,
377 result_descr = self.detensorize(result)
379 np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
381 acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
382 acc_nb_results = len(result_descr)
384 nb_requested_properties = sum(acc_nb_requested_properties)
385 nb_missing_properties = sum(acc_nb_missing_properties)
388 logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
390 f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
393 f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
396 img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
400 img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
404 torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
410 image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
411 torchvision.utils.save_image(
412 img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
414 logger(f"wrote {image_name}")
417 ######################################################################
422 self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
426 self.nb_train_samples = nb_train_samples
427 self.nb_test_samples = nb_test_samples
428 self.batch_size = batch_size
430 data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
431 self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
432 data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
433 self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
435 def batches(self, split="train", nb_to_use=-1, desc=None):
436 assert split in {"train", "test"}
437 input = self.train_input if split == "train" else self.test_input
439 input = input[:nb_to_use]
441 desc = f"epoch-{split}"
442 for batch in tqdm.tqdm(
443 input.split(self.batch_size), dynamic_ncols=True, desc=desc
447 def vocabulary_size(self):
451 self, n_epoch, model, result_dir, logger, deterministic_synthesis
453 results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
454 ar_mask = torch.full_like(results, 1)
455 masked_inplace_autoregression(
460 deterministic_synthesis,
463 image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
464 torchvision.utils.save_image(
465 1 - results.reshape(-1, 1, 28, 28) / 255.0,
470 logger(f"wrote {image_name}")
473 ######################################################################
479 def map2seq(self, *m):
480 return torch.cat([x.flatten(1) for x in m], 1)
482 def seq2map(self, s):
483 s = s.reshape(s.size(0), -1, self.height, self.width)
484 return (s[:, k] for k in range(s.size(1)))
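# map2seq() flattens each (N, height, width) map and concatenates them along
# the token dimension; seq2map() is the inverse, e.g. a sequence of length
# 2 * height * width is reshaped to (N, 2, height, width) and split back into
# the (maze, path) pair.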
494 device=torch.device("cpu"),
498 self.batch_size = batch_size
503 train_mazes, train_paths, _ = maze.create_maze_data(
508 progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-train"),
510 self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
512 test_mazes, test_paths, _ = maze.create_maze_data(
517 progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc="data-test"),
519 self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
521 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
523 def batches(self, split="train", nb_to_use=-1, desc=None):
524 assert split in {"train", "test"}
525 input = self.train_input if split == "train" else self.test_input
527 input = input[:nb_to_use]
529 desc = f"epoch-{split}"
530 for batch in tqdm.tqdm(
531 input.split(self.batch_size), dynamic_ncols=True, desc=desc
535 def vocabulary_size(self):
539 self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
541 nb_total, nb_correct = 0, 0
543 self.width * self.height,
544 self.width * self.height,
549 for input in self.batches(split, nb_to_use):
550 result = input.clone()
551 ar_mask = result.new_zeros(result.size())
552 ar_mask[:, self.height * self.width :] = 1
553 result *= 1 - ar_mask
554 masked_inplace_autoregression(
559 deterministic_synthesis,
560 progress_bar_desc=None,
563 mazes, paths = self.seq2map(result)
564 path_correctness = maze.path_correctness(mazes, paths)
565 nb_correct += path_correctness.long().sum()
566 nb_total += mazes.size(0)
568 optimal_path_lengths = (
569 (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
571 predicted_path_lengths = (
572 (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
574 optimal_path_lengths = optimal_path_lengths[path_correctness]
575 predicted_path_lengths = predicted_path_lengths[path_correctness]
576 count[optimal_path_lengths, predicted_path_lengths] += 1
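# count is a 2D histogram over the correctly solved mazes: entry [i, j] counts
# sequences whose optimal path contains i path cells and whose predicted path
# contains j, so the diagonal corresponds to length-optimal predictions (used
# below for proportion_optimal).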
582 : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
585 return nb_total, nb_correct, count
588 self, n_epoch, model, result_dir, logger, deterministic_synthesis
590 train_nb_total, train_nb_correct, count = self.compute_error(
594 deterministic_synthesis=deterministic_synthesis,
597 f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
600 test_nb_total, test_nb_correct, count = self.compute_error(
604 deterministic_synthesis=deterministic_synthesis,
607 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
610 if count is not None:
611 proportion_optimal = count.diagonal().sum().float() / count.sum()
612 logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
614 os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
616 for i in range(count.size(0)):
617 for j in range(count.size(1)):
618 eol = " " if j < count.size(1) - 1 else "\n"
619 f.write(f"{count[i,j]}{eol}")
621 input = self.test_input[:48]
622 result = input.clone()
623 ar_mask = result.new_zeros(result.size())
624 ar_mask[:, self.height * self.width :] = 1
625 result *= 1 - ar_mask
626 masked_inplace_autoregression(
631 deterministic_synthesis,
635 mazes, paths = self.seq2map(input)
636 _, predicted_paths = self.seq2map(result)
638 filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
643 predicted_paths=predicted_paths,
644 path_correct=maze.path_correctness(mazes, predicted_paths),
645 path_optimal=maze.path_optimality(paths, predicted_paths),
647 logger(f"wrote {filename}")
650 ######################################################################
667 device=torch.device("cpu"),
671 self.batch_size = batch_size
675 self.prompt_length = prompt_length
677 self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
686 self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
696 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
698 def batches(self, split="train", nb_to_use=-1, desc=None):
699 assert split in {"train", "test"}
700 input = self.train_input if split == "train" else self.test_input
702 input = input[:nb_to_use]
704 desc = f"epoch-{split}"
705 for batch in tqdm.tqdm(
706 input.split(self.batch_size), dynamic_ncols=True, desc=desc
710 def vocabulary_size(self):
714 self, n_epoch, model, result_dir, logger, deterministic_synthesis
716 def compute_nb_correct(input, prior_visits):
717 result = input.clone()
718 i = torch.arange(result.size(1), device=result.device)[None, :]
720 torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
724 result *= 1 - ar_mask
726 masked_inplace_autoregression(
731 deterministic_synthesis,
735 nb_total = ((prior_visits > 0) * ar_mask).sum()
737 nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
739 return nb_total, nb_correct
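# Only every second position past twice the prompt length is predicted
# (i % 2 == 0), presumably because the generated sequences interleave two
# token streams per time step, and the counts are restricted to positions
# whose cell had already been visited (prior_visits > 0).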
741 test_nb_total, test_nb_correct = compute_nb_correct(
742 self.test_input[:1000], self.test_prior_visits[:1000]
746 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
750 ######################################################################
766 fraction_values_for_train=None,
767 device=torch.device("cpu"),
771 self.batch_size = batch_size
772 self.nb_steps = nb_steps
773 self.nb_stacks = nb_stacks
774 self.nb_digits = nb_digits
777 if fraction_values_for_train is None:
778 values_for_train = None
779 values_for_test = None
781 all = torch.randperm(10**nb_digits)
782 nb_for_train = int(all.size(0) * fraction_values_for_train)
783 values_for_train = all[:nb_for_train]
784 values_for_test = all[nb_for_train:]
786 self.train_input, self.train_stack_counts = stack.generate_sequences(
795 self.test_input, self.test_stack_counts = stack.generate_sequences(
804 i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
805 counts = self.test_stack_counts.flatten()[i.flatten()]
806 counts = F.one_hot(counts).sum(0)
807 logger(f"test_pop_stack_counts {counts}")
809 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
811 def batches(self, split="train", nb_to_use=-1, desc=None):
812 assert split in {"train", "test"}
813 input = self.train_input if split == "train" else self.test_input
815 input = input[:nb_to_use]
817 desc = f"epoch-{split}"
818 for batch in tqdm.tqdm(
819 input.split(self.batch_size), dynamic_ncols=True, desc=desc
823 def vocabulary_size(self):
827 self, n_epoch, model, result_dir, logger, deterministic_synthesis
829 def compute_nb_correct(input):
830 result = input.clone()
831 stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
832 ar_mask = (result != input).long()
833 masked_inplace_autoregression(
838 deterministic_synthesis,
842 errors = ((result != input).long() * ar_mask).reshape(
843 -1, 1 + self.nb_digits
845 ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
847 nb_total = ar_mask.max(1).values.sum()
848 nb_correct = nb_total - errors.max(1).values.sum()
850 return nb_total, nb_correct
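# The evaluation masks exactly the tokens blanked by
# stack.remove_popped_values() (ar_mask = result != input), regenerates them,
# and then groups the tokens per value (1 + nb_digits of them): a popped value
# counts as correct only if none of its tokens is wrong (max over the group).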
852 test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
855 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
858 ##############################################################
859 # Log a few generated sequences
860 input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
861 result = input.clone()
862 stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
863 ar_mask = (result != input).long()
865 # for n in range(result.size(0)):
867 # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
870 masked_inplace_autoregression(
875 deterministic_synthesis,
879 for n in range(result.size(0)):
881 f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
883 ##############################################################
886 ######################################################################
893 def tensorize(self, sequences):
894 len_max = max([len(x) for x in sequences])
899 [self.char2id[c] for c in s + "#" * (len_max - len(s))]
916 device=torch.device("cpu"),
920 self.batch_size = batch_size
923 train_sequences = expr.generate_sequences(
925 nb_variables=nb_variables,
926 length=sequence_length,
927 operand_max=operand_max,
928 result_max=result_max,
931 test_sequences = expr.generate_sequences(
933 nb_variables=nb_variables,
934 length=sequence_length,
935 operand_max=operand_max,
936 result_max=result_max,
939 symbols = list(set("#" + "".join(train_sequences + test_sequences)))
942 self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
943 self.id2char = dict([(n, c) for c, n in self.char2id.items()])
945 self.filler, self.space = self.char2id["#"], self.char2id[" "]
947 self.train_input = self.tensorize(train_sequences)
948 self.test_input = self.tensorize(test_sequences)
950 self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
952 def batches(self, split="train", nb_to_use=-1, desc=None):
953 assert split in {"train", "test"}
954 input = self.train_input if split == "train" else self.test_input
956 input = input[:nb_to_use]
958 desc = f"epoch-{split}"
959 for batch in tqdm.tqdm(
960 input.split(self.batch_size), dynamic_ncols=True, desc=desc
962 last = (batch != self.filler).max(0).values.nonzero().max() + 3
963 batch = batch[:, :last]
966 def vocabulary_size(self):
969 def seq2str(self, s):
970 return "".join([self.id2char[k.item()] for k in s])
978 deterministic_synthesis,
981 def compute_nb_correct(input):
982 result = input.clone()
983 s = (result == self.space).long()
984 ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
985 result = (1 - ar_mask) * result + ar_mask * self.filler
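# ar_mask flags every position strictly after the first space, i.e. the part
# of the sequence holding the values to predict; that span is overwritten with
# the filler token before regeneration, and a sequence is counted as correct
# below only if every single token matches the input.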
986 masked_inplace_autoregression(
991 deterministic_synthesis,
995 nb_total = input.size(0)
996 nb_correct = (input == result).long().min(1).values.sum()
998 #######################################################################
999 # Compute predicted vs. true variable values
1001 nb_delta = torch.zeros(5, dtype=torch.int64)
1004 values_input = expr.extract_results([self.seq2str(s) for s in input])
1005 values_result = expr.extract_results([self.seq2str(s) for s in result])
1007 filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
1009 with open(filename, "w") as f:
1010 for i, r in zip(values_input, values_result):
1011 for n, vi in i.items():
1013 f.write(f"{vi} {-1 if vr is None else vr}\n")
1015 if vr is None or vr < 0:
1019 if d >= nb_delta.size(0):
1024 ######################################################################
1026 return nb_total, nb_correct, nb_delta, nb_missed
1033 ) = compute_nb_correct(self.test_input[:10000])
1036 f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
1039 nb_total = test_nb_delta.sum() + test_nb_missed
1040 for d in range(test_nb_delta.size(0)):
1042 f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
1045 f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
1048 ##############################################################
1049 # Log a few generated sequences
1050 if input_file is None:
1051 input = self.test_input[:10]
1053 with open(input_file, "r") as f:
1054 sequences = [e.strip() for e in f.readlines()]
1055 sequences = [s + " " + "#" * 50 for s in sequences]
1056 input = self.tensorize(sequences)
1058 result = input.clone()
1059 s = (result == self.space).long()
1060 ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
1061 result = (1 - ar_mask) * result + ar_mask * self.filler
1063 for n in range(result.size(0)):
1064 logger(f"test_before {self.seq2str(result[n])}")
1066 masked_inplace_autoregression(
1071 deterministic_synthesis,
1075 correct = (1 - ar_mask) * self.space + ar_mask * input
1076 for n in range(result.size(0)):
1077 comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
1078 logger(f"test_after {self.seq2str(result[n])} {comment}")
1079 logger(f"truth {self.seq2str(correct[n])}")
1080 ##############################################################
1083 ######################################################################
1096 device=torch.device("cpu"),
1097 device_storage=torch.device("cpu"),
1101 self.batch_size = batch_size
1102 self.device = device
1111 ) = world.create_data_and_processors(
1116 nb_epochs=vqae_nb_epochs,
1119 device_storage=device_storage,
1122 print(f"{train_action_seq.size()=}")
1124 train_frame_seq = self.frame2seq(train_frames).to(device_storage)
1125 test_frame_seq = self.frame2seq(test_frames).to(device_storage)
1127 nb_frame_codes = max(train_frame_seq.max(), test_frame_seq.max()) + 1
1128 nb_action_codes = max(train_action_seq.max(), test_action_seq.max()) + 1
1130 self.len_frame_seq = train_frame_seq.size(1)
1131 self.len_action_seq = train_action_seq.size(1)
1132 self.nb_codes = nb_frame_codes + nb_action_codes
1134 train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
1135 print(f"{train_action_seq.device=} {nb_frame_codes.device=}")
1136 train_action_seq += nb_frame_codes
1137 self.train_input = torch.cat(
1138 (train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
1141 test_frame_seq = test_frame_seq.reshape(test_frame_seq.size(0) // 2, 2, -1)
1142 test_action_seq += nb_frame_codes
1143 self.test_input = torch.cat(
1144 (test_frame_seq[:, 0, :], test_action_seq, test_frame_seq[:, 1, :]), 1
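# Each sample is laid out as [first frame tokens, action tokens, second frame
# tokens], with the action codes offset by nb_frame_codes so that the two
# vocabularies do not collide; produce_results() below regenerates the second
# frame from the first frame and the actions.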
1147 def batches(self, split="train", nb_to_use=-1, desc=None):
1148 assert split in {"train", "test"}
1149 input = self.train_input if split == "train" else self.test_input
1151 input = input[:nb_to_use]
1153 desc = f"epoch-{split}"
1154 for batch in tqdm.tqdm(
1155 input.split(self.batch_size), dynamic_ncols=True, desc=desc
1157 yield batch.to(self.device)
1159 def vocabulary_size(self):
1160 return self.nb_codes
1162 def produce_results(
1163 self, n_epoch, model, result_dir, logger, deterministic_synthesis
1166 2 * self.len_frame_seq + self.len_action_seq, device=self.device
1169 input = self.test_input[:64].to(self.device)
1170 result = input.clone()
1173 (k >= self.len_frame_seq + self.len_action_seq).long().expand_as(result)
1175 result *= 1 - ar_mask
1177 masked_inplace_autoregression(
1182 deterministic_synthesis,
1186 seq_start = input[:, : self.len_frame_seq]
1187 seq_end = input[:, self.len_frame_seq + self.len_action_seq :]
1188 seq_predicted = result[:, self.len_frame_seq + self.len_action_seq :]
1191 (seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
1193 result = result.reshape(-1, result.size(-1))
1194 print(f"{result.size()=}")
1196 frames = self.seq2frame(result)
1197 image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")
1198 torchvision.utils.save_image(
1199 frames.float() / (world.Box.nb_rgb_levels - 1),
1205 logger(f"wrote {image_name}")
1208 ######################################################################