def create_fourier_basis(self):
self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
self.fourier_basis_inverse = self.fourier_basis.inverse()
-
- torchvision.utils.save_image(
- 1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(),
- "fourier.png",
- nrow=28,
- )
-
y = self.train_input.float() @ self.fourier_basis.t()
self.fourier_range = 4
self.fourier_mu = y.mean(dim=0, keepdim=True)
self.train_input = self.fourier_encode(self.train_input)
self.test_input = self.fourier_encode(self.test_input)
- torchvision.utils.save_image(
- 1
- - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
- / 256,
- "check-train.png",
- nrow=16,
- )
- torchvision.utils.save_image(
- 1
- - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
- / 256,
- "check-test.png",
- nrow=16,
- )
- # exit(0)
-
- print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}")
-
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
def produce_results(
self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
+ if n_epoch == 0:
+ image_name = os.path.join(result_dir, "fourier.png")
+ torchvision.utils.save_image(
+ 0.5
+ - 0.5
+ * self.fourier_basis.reshape(-1, 1, 28, 28)
+ / self.fourier_basis.std(),
+ image_name,
+ nrow=28,
+ )
+
+ image_name = os.path.join(result_dir, "check-train.png")
+ torchvision.utils.save_image(
+ 1
+ - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+ / 256,
+ image_name,
+ nrow=16,
+ )
+
results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
ar_mask = torch.full_like(results, 1)
masked_inplace_autoregression(