- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
- result_descr = self.detensorize(result)
-
- np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
-
- acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
- acc_nb_results = len(result_descr)
-
- nb_requested_properties = sum(acc_nb_requested_properties)
- nb_missing_properties = sum(acc_nb_missing_properties)
-
- prefix = "demo_"
- logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
- logger(
- f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
- )
- logger(
- f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
- )
-
- img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
-
- if img.dim() == 5:
- if img.size(1) == 1:
- img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
- else:
- img = torch.cat(
- [
- torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
- for x in img
- ],
- 0,
- )
-
- image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
- torchvision.utils.save_image(
- img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
- )
- logger(f"wrote {image_name}")
-
-
-######################################################################
-
-
class MNIST(Task):
    """Autoregressive modeling task over flattened MNIST images.

    Each 28x28 image is reshaped into a row of 784 int64 values, and the
    vocabulary is fixed to 256 symbols.
    """

    def __init__(
        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
    ):
        """Load the leading samples of the MNIST train and test splits.

        Args:
            nb_train_samples: number of training images to keep.
            nb_test_samples: number of test images to keep.
            batch_size: batch size used by `batches` and passed to the
                generation routine in `produce_results`.
            device: device on which `produce_results` allocates its
                tensors. The stored images are not moved to it here.
        """
        super().__init__()

        # Stored as plain integers (not 1-tuples) so they can be used
        # directly in arithmetic and comparisons.
        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
        self.batch_size = batch_size
        self.device = device

        # The dataset is downloaded into ./data if it is not already there.
        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield successive batches of the requested split.

        Args:
            split: either "train" or "test".
            nb_to_use: if positive, only the first `nb_to_use` samples
                are used; otherwise the whole split is used.
            desc: label of the progress bar; defaults to "epoch-<split>".

        Yields:
            Tensors holding at most `self.batch_size` rows of the split.
        """
        assert split in {"train", "test"}
        samples = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            samples = samples[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            samples.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Return the number of distinct token values (256)."""
        return 256

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        """Generate 64 images with the model and save them as a PNG grid.

        Args:
            n_epoch: epoch index, used in the output file name.
            model: model handed to `masked_inplace_autoregression`.
            result_dir: directory in which the image is written.
            logger: callable taking a string, used to report the file name.
            deterministic_synthesis: forwarded unchanged to
                `masked_inplace_autoregression`.
        """
        # Uninitialized buffer: masked_inplace_autoregression is expected to
        # fill it in place (its return value is not used here).
        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
        # Mask of ones over every position -- presumably marks all of them
        # as to-be-generated; confirm against masked_inplace_autoregression.
        ar_mask = torch.full_like(results, 1)
        masked_inplace_autoregression(
            model,
            self.batch_size,
            results,
            ar_mask,
            deterministic_synthesis,
            device=self.device,
        )
        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
        # Intensities are rescaled to [0, 1] and inverted before saving.
        torchvision.utils.save_image(
            1 - results.reshape(-1, 1, 28, 28) / 255.0,
            image_name,
            nrow=16,
            pad_value=0.8,
        )
        logger(f"wrote {image_name}")
-
-
-######################################################################
-
-import maze
-
-
-class Maze(Task):
- def map2seq(self, *m):
- return torch.cat([x.flatten(1) for x in m], 1)
-
- def seq2map(self, s):
- s = s.reshape(s.size(0), -1, self.height, self.width)
- return (s[:, k] for k in range(s.size(1)))
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- height,
- width,
- nb_walls,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.height = height
- self.width = width
- self.device = device
-
- train_mazes, train_paths, _ = maze.create_maze_data(
- nb_train_samples,
- height=height,
- width=width,
- nb_walls=nb_walls,
- progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
- )
- self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
-
- test_mazes, test_paths, _ = maze.create_maze_data(
- nb_test_samples,
- height=height,
- width=width,
- nb_walls=nb_walls,
- progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
- )
- self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def compute_error(
- self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
- ):
- nb_total, nb_correct = 0, 0
- count = torch.zeros(
- self.width * self.height,
- self.width * self.height,
- device=self.device,
- dtype=torch.int64,
- )
-
- for input in self.batches(split, nb_to_use):
- result = input.clone()
- ar_mask = result.new_zeros(result.size())
- ar_mask[:, self.height * self.width :] = 1
- result *= 1 - ar_mask
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
- mazes, paths = self.seq2map(result)
- path_correctness = maze.path_correctness(mazes, paths)
- nb_correct += path_correctness.long().sum()
- nb_total += mazes.size(0)
-
- optimal_path_lengths = (
- (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
- )
- predicted_path_lengths = (
- (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
- )
- optimal_path_lengths = optimal_path_lengths[path_correctness]
- predicted_path_lengths = predicted_path_lengths[path_correctness]
- count[optimal_path_lengths, predicted_path_lengths] += 1
-
- if count.max() == 0:
- count = None
- else:
- count = count[
- : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
- ]
-
- return nb_total, nb_correct, count
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- train_nb_total, train_nb_correct, count = self.compute_error(
- model,
- "train",
- nb_to_use=1000,
- deterministic_synthesis=deterministic_synthesis,
- )
- logger(
- f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- )
-
- test_nb_total, test_nb_correct, count = self.compute_error(
- model,
- "test",
- nb_to_use=1000,